diff --git a/.gitignore b/.gitignore index 88d7b56181be7607e315b5f01990ed8951336dc9..155e785b01beb809a13c45c40d96f04f4dd6343b 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ log/ spark-tests.log streaming-tests.log dependency-reduced-pom.xml +.ensime +.ensime_lucene diff --git a/bagel/src/main/scala/spark/bagel/Bagel.scala b/bagel/src/main/scala/spark/bagel/Bagel.scala index 996ca2a8771e5c5d30b88ff0b97f0a2e3c85b955..094e57dacb7cc313735c3f80556248c8d58f2e20 100644 --- a/bagel/src/main/scala/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/spark/bagel/Bagel.scala @@ -6,19 +6,19 @@ import spark.SparkContext._ import scala.collection.mutable.ArrayBuffer object Bagel extends Logging { - def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, - C : Manifest, A : Manifest]( + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, + C: Manifest, A: Manifest]( sc: SparkContext, vertices: RDD[(K, V)], messages: RDD[(K, M)], combiner: Combiner[M, C], aggregator: Option[Aggregator[V, A]], partitioner: Partitioner, - numSplits: Int + numPartitions: Int )( compute: (V, Option[C], Option[A], Int) => (V, Array[M]) ): RDD[(K, V)] = { - val splits = if (numSplits != 0) numSplits else sc.defaultParallelism + val splits = if (numPartitions != 0) numPartitions else sc.defaultParallelism var superstep = 0 var verts = vertices @@ -50,49 +50,47 @@ object Bagel extends Logging { verts } - def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, - C : Manifest]( + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( sc: SparkContext, vertices: RDD[(K, V)], messages: RDD[(K, M)], combiner: Combiner[M, C], partitioner: Partitioner, - numSplits: Int + numPartitions: Int )( compute: (V, Option[C], Int) => (V, Array[M]) ): RDD[(K, V)] = { run[K, V, M, C, Nothing]( - sc, vertices, messages, combiner, None, partitioner, numSplits)( + sc, vertices, messages, combiner, None, partitioner, numPartitions)( addAggregatorArg[K, V, M, C](compute)) } - def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, - C : Manifest]( + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( sc: SparkContext, vertices: RDD[(K, V)], messages: RDD[(K, M)], combiner: Combiner[M, C], - numSplits: Int + numPartitions: Int )( compute: (V, Option[C], Int) => (V, Array[M]) ): RDD[(K, V)] = { - val part = new HashPartitioner(numSplits) + val part = new HashPartitioner(numPartitions) run[K, V, M, C, Nothing]( - sc, vertices, messages, combiner, None, part, numSplits)( + sc, vertices, messages, combiner, None, part, numPartitions)( addAggregatorArg[K, V, M, C](compute)) } - def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( + def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( sc: SparkContext, vertices: RDD[(K, V)], messages: RDD[(K, M)], - numSplits: Int + numPartitions: Int )( compute: (V, Option[Array[M]], Int) => (V, Array[M]) ): RDD[(K, V)] = { - val part = new HashPartitioner(numSplits) + val part = new HashPartitioner(numPartitions) run[K, V, M, Array[M], Nothing]( - sc, vertices, messages, new DefaultCombiner(), None, part, numSplits)( + sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions)( addAggregatorArg[K, V, M, Array[M]](compute)) } @@ -100,7 +98,7 @@ object Bagel extends Logging { * Aggregates the given vertices using the given aggregator, if it * is specified. */ - private def agg[K, V <: Vertex, A : Manifest]( + private def agg[K, V <: Vertex, A: Manifest]( verts: RDD[(K, V)], aggregator: Option[Aggregator[V, A]] ): Option[A] = aggregator match { @@ -116,7 +114,7 @@ object Bagel extends Logging { * function. Returns the processed RDD, the number of messages * created, and the number of active vertices. */ - private def comp[K : Manifest, V <: Vertex, M <: Message[K], C]( + private def comp[K: Manifest, V <: Vertex, M <: Message[K], C]( sc: SparkContext, grouped: RDD[(K, (Seq[C], Seq[V]))], compute: (V, Option[C]) => (V, Array[M]) @@ -149,9 +147,7 @@ object Bagel extends Logging { * Converts a compute function that doesn't take an aggregator to * one that does, so it can be passed to Bagel.run. */ - private def addAggregatorArg[ - K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C - ]( + private def addAggregatorArg[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C]( compute: (V, Option[C], Int) => (V, Array[M]) ): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = { (vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) => @@ -170,7 +166,7 @@ trait Aggregator[V, A] { def mergeAggregators(a: A, b: A): A } -class DefaultCombiner[M : Manifest] extends Combiner[M, Array[M]] with Serializable { +class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { def createCombiner(msg: M): Array[M] = Array(msg) def mergeMsg(combiner: Array[M], msg: M): Array[M] = diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala index 03843019c010cf54c7c0209967d17503078862b1..bc32663e0fde6e9bfd371178525894840f630a47 100644 --- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala +++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala @@ -16,7 +16,7 @@ import scala.xml.{XML,NodeSeq} object WikipediaPageRank { def main(args: Array[String]) { if (args.length < 5) { - System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numSplits> <host> <usePartitioner>") + System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numPartitions> <host> <usePartitioner>") System.exit(-1) } @@ -25,7 +25,7 @@ object WikipediaPageRank { val inputFile = args(0) val threshold = args(1).toDouble - val numSplits = args(2).toInt + val numPartitions = args(2).toInt val host = args(3) val usePartitioner = args(4).toBoolean val sc = new SparkContext(host, "WikipediaPageRank") @@ -69,7 +69,7 @@ object WikipediaPageRank { val result = Bagel.run( sc, vertices, messages, combiner = new PRCombiner(), - numSplits = numSplits)( + numPartitions = numPartitions)( utils.computeWithCombiner(numVertices, epsilon)) // Print the result diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala index 06cc8c748b98293819925d6705f17355512d3ca4..9d9d80d809d0d58ac0d6b52f5c581646cf5323b8 100644 --- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala +++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala @@ -88,7 +88,7 @@ object WikipediaPageRankStandalone { n: Long, partitioner: Partitioner, usePartitioner: Boolean, - numSplits: Int + numPartitions: Int ): RDD[(String, Double)] = { var ranks = links.mapValues { edges => defaultRank } for (i <- 1 to numIterations) { diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala index 3c2f9c4616fe206ca7ddd7d1e004d9d4b445279e..47829a431e871489b622ae0ea5e050163e0b1427 100644 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/bagel/BagelSuite.scala @@ -1,10 +1,8 @@ package spark.bagel import org.scalatest.{FunSuite, Assertions, BeforeAndAfter} -import org.scalatest.prop.Checkers -import org.scalacheck.Arbitrary._ -import org.scalacheck.Gen -import org.scalacheck.Prop._ +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ import scala.collection.mutable.ArrayBuffer @@ -13,7 +11,7 @@ import spark._ class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable -class BagelSuite extends FunSuite with Assertions with BeforeAndAfter { +class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts { var sc: SparkContext = _ @@ -25,7 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") } - + test("halting by voting") { sc = new SparkContext("local", "test") val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) @@ -36,8 +34,9 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter { (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) } - for ((id, vert) <- result.collect) + for ((id, vert) <- result.collect) { assert(vert.age === numSupersteps) + } } test("halting by message silence") { @@ -57,7 +56,27 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter { } (new TestVertex(self.active, self.age + 1), msgsOut) } - for ((id, vert) <- result.collect) + for ((id, vert) <- result.collect) { assert(vert.age === numSupersteps) + } + } + + test("large number of iterations") { + // This tests whether jobs with a large number of iterations finish in a reasonable time, + // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang + failAfter(10 seconds) { + sc = new SparkContext("local", "test") + val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) + val msgs = sc.parallelize(Array[(String, TestMessage)]()) + val numSupersteps = 50 + val result = + Bagel.run(sc, verts, msgs, sc.defaultParallelism) { + (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => + (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) + } + for ((id, vert) <- result.collect) { + assert(vert.age === numSupersteps) + } + } } } diff --git a/core/src/main/scala/spark/CacheManager.scala b/core/src/main/scala/spark/CacheManager.scala index 711435c3335fae341332118b6215cb87ae12d5c2..c7b379a3fbe2d8da307693ec12f20de00832e339 100644 --- a/core/src/main/scala/spark/CacheManager.scala +++ b/core/src/main/scala/spark/CacheManager.scala @@ -11,13 +11,13 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { private val loading = new HashSet[String] /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */ - def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel) + def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel) : Iterator[T] = { val key = "rdd_%d_%d".format(rdd.id, split.index) logInfo("Cache key is " + key) blockManager.get(key) match { case Some(cachedValues) => - // Split is in cache, so just return its values + // Partition is in cache, so just return its values logInfo("Found partition in cache!") return cachedValues.asInstanceOf[Iterator[T]] diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/DoubleRDDFunctions.scala index b2a0e2b631e7aabcd8892a93aec2d0bbbbb28a35..178d31a73b9130fce8d96bc57ca0c1f82a4876a8 100644 --- a/core/src/main/scala/spark/DoubleRDDFunctions.scala +++ b/core/src/main/scala/spark/DoubleRDDFunctions.scala @@ -42,14 +42,14 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** (Experimental) Approximate operation to return the mean within a timeout. */ def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new MeanEvaluator(self.splits.size, confidence) + val evaluator = new MeanEvaluator(self.partitions.size, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } /** (Experimental) Approximate operation to return the sum within a timeout. */ def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new SumEvaluator(self.splits.size, confidence) + val evaluator = new SumEvaluator(self.partitions.size, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } } diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index cc3cca25713592fa009a0d70ac94b35118b9a422..e7408e4352abfc2e53e20f91039e186ff72c139f 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -23,6 +23,7 @@ import spark.partial.BoundedDouble import spark.partial.PartialResult import spark.rdd._ import spark.SparkContext._ +import spark.Partitioner._ /** * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. @@ -62,7 +63,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) - if (mapSideCombine) { + if (self.partitioner == Some(partitioner)) { + self.mapPartitions(aggregator.combineValuesByKey(_), true) + } else if (mapSideCombine) { val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true) val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner) partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true) @@ -81,8 +84,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, - numSplits: Int): RDD[(K, C)] = { - combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numSplits)) + numPartitions: Int): RDD[(K, C)] = { + combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) } /** @@ -143,10 +146,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( /** * Merge the values for each key using an associative reduce function. This will also perform * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numSplits splits. + * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. */ - def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = { - reduceByKey(new HashPartitioner(numSplits), func) + def reduceByKey(func: (V, V) => V, numPartitions: Int): RDD[(K, V)] = { + reduceByKey(new HashPartitioner(numPartitions), func) } /** @@ -164,10 +167,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( /** * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with into `numSplits` partitions. + * resulting RDD with into `numPartitions` partitions. */ - def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = { - groupByKey(new HashPartitioner(numSplits)) + def groupByKey(numPartitions: Int): RDD[(K, Seq[V])] = { + groupByKey(new HashPartitioner(numPartitions)) } /** @@ -246,8 +249,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( } /** - * Simplified version of combineByKey that hash-partitions the resulting RDD using the default - * parallelism level. + * Simplified version of combineByKey that hash-partitions the resulting RDD using the + * existing partitioner/parallelism level. */ def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) : RDD[(K, C)] = { @@ -257,7 +260,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( /** * Merge the values for each key using an associative reduce function. This will also perform * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level. + * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ + * parallelism level. */ def reduceByKey(func: (V, V) => V): RDD[(K, V)] = { reduceByKey(defaultPartitioner(self), func) @@ -265,7 +269,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( /** * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with the default parallelism level. + * resulting RDD with the existing partitioner/parallelism level. */ def groupByKey(): RDD[(K, Seq[V])] = { groupByKey(defaultPartitioner(self)) @@ -285,15 +289,15 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and * (k, v2) is in `other`. Performs a hash join across the cluster. */ - def join[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, W))] = { - join(other, new HashPartitioner(numSplits)) + def join[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, W))] = { + join(other, new HashPartitioner(numPartitions)) } /** * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output - * using the default level of parallelism. + * using the existing partitioner/parallelism level. */ def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = { leftOuterJoin(other, defaultPartitioner(self, other)) @@ -303,17 +307,17 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output - * into `numSplits` partitions. + * into `numPartitions` partitions. */ - def leftOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, Option[W]))] = { - leftOuterJoin(other, new HashPartitioner(numSplits)) + def leftOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (V, Option[W]))] = { + leftOuterJoin(other, new HashPartitioner(numPartitions)) } /** * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting - * RDD using the default parallelism level. + * RDD using the existing partitioner/parallelism level. */ def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = { rightOuterJoin(other, defaultPartitioner(self, other)) @@ -325,8 +329,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting * RDD into the given number of partitions. */ - def rightOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Option[V], W))] = { - rightOuterJoin(other, new HashPartitioner(numSplits)) + def rightOuterJoin[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Option[V], W))] = { + rightOuterJoin(other, new HashPartitioner(numPartitions)) } /** @@ -361,7 +365,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K]( - Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]), + Seq(self.asInstanceOf[RDD[(K, _)]], other.asInstanceOf[RDD[(K, _)]]), partitioner) val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) prfs.mapValues { @@ -380,9 +384,9 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K]( - Seq(self.asInstanceOf[RDD[(_, _)]], - other1.asInstanceOf[RDD[(_, _)]], - other2.asInstanceOf[RDD[(_, _)]]), + Seq(self.asInstanceOf[RDD[(K, _)]], + other1.asInstanceOf[RDD[(K, _)]], + other2.asInstanceOf[RDD[(K, _)]]), partitioner) val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) prfs.mapValues { @@ -412,17 +416,17 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. */ - def cogroup[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Seq[V], Seq[W]))] = { - cogroup(other, new HashPartitioner(numSplits)) + def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Seq[V], Seq[W]))] = { + cogroup(other, new HashPartitioner(numPartitions)) } /** * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a * tuple with the list of values for that key in `this`, `other1` and `other2`. */ - def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numSplits: Int) + def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numPartitions: Int) : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { - cogroup(other1, other2, new HashPartitioner(numSplits)) + cogroup(other1, other2, new HashPartitioner(numPartitions)) } /** Alias for cogroup. */ @@ -436,17 +440,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( cogroup(other1, other2, defaultPartitioner(self, other1, other2)) } - /** - * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. If any of - * the RDDs already has a partitioner, choose that one, otherwise use a default HashPartitioner. - */ - def defaultPartitioner(rdds: RDD[_]*): Partitioner = { - for (r <- rdds if r.partitioner != None) { - return r.partitioner.get - } - return new HashPartitioner(self.context.defaultParallelism) - } - /** * Return the list of values in the RDD for key `key`. This operation is done efficiently if the * RDD has a known partitioner by only searching the partition that the key maps to. @@ -634,9 +627,9 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in * order of the keys). */ - def sortByKey(ascending: Boolean = true, numSplits: Int = self.splits.size): RDD[(K,V)] = { + def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[(K,V)] = { val shuffled = - new ShuffledRDD[K, V](self, new RangePartitioner(numSplits, self, ascending)) + new ShuffledRDD[K, V](self, new RangePartitioner(numPartitions, self, ascending)) shuffled.mapPartitions(iter => { val buf = iter.toArray if (ascending) { @@ -650,9 +643,9 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( private[spark] class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev) { - override def getSplits = firstParent[(K, V)].splits + override def getPartitions = firstParent[(K, V)].partitions override val partitioner = firstParent[(K, V)].partitioner - override def compute(split: Split, context: TaskContext) = + override def compute(split: Partition, context: TaskContext) = firstParent[(K, V)].iterator(split, context).map{ case (k, v) => (k, f(v)) } } @@ -660,9 +653,9 @@ private[spark] class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]) extends RDD[(K, U)](prev) { - override def getSplits = firstParent[(K, V)].splits + override def getPartitions = firstParent[(K, V)].partitions override val partitioner = firstParent[(K, V)].partitioner - override def compute(split: Split, context: TaskContext) = { + override def compute(split: Partition, context: TaskContext) = { firstParent[(K, V)].iterator(split, context).flatMap { case (k, v) => f(v).map(x => (k, x)) } } } diff --git a/core/src/main/scala/spark/Split.scala b/core/src/main/scala/spark/Partition.scala similarity index 84% rename from core/src/main/scala/spark/Split.scala rename to core/src/main/scala/spark/Partition.scala index 90d4b47c553c535ff3271da0ddceb66a7ed832f9..e384308ef6e939c342f8e3f36af0115a87ac8838 100644 --- a/core/src/main/scala/spark/Split.scala +++ b/core/src/main/scala/spark/Partition.scala @@ -3,7 +3,7 @@ package spark /** * A partition of an RDD. */ -trait Split extends Serializable { +trait Partition extends Serializable { /** * Get the split's index within its parent RDD */ diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala index 9d5b966e1e5a092b542a2621aa2e91c17a9b299a..eec0e8dd79da4e70994a57c021077245de612e56 100644 --- a/core/src/main/scala/spark/Partitioner.scala +++ b/core/src/main/scala/spark/Partitioner.scala @@ -9,6 +9,38 @@ abstract class Partitioner extends Serializable { def getPartition(key: Any): Int } +object Partitioner { + + private val useDefaultParallelism = System.getProperty("spark.default.parallelism") != null + + /** + * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. + * + * If any of the RDDs already has a partitioner, choose that one. + * + * Otherwise, we use a default HashPartitioner. For the number of partitions, if + * spark.default.parallelism is set, then we'll use the value from SparkContext + * defaultParallelism, otherwise we'll use the max number of upstream partitions. + * + * Unless spark.default.parallelism is set, He number of partitions will be the + * same as the number of partitions in the largest upstream RDD, as this should + * be least likely to cause out-of-memory errors. + * + * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD. + */ + def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { + val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse + for (r <- bySize if r.partitioner != None) { + return r.partitioner.get + } + if (useDefaultParallelism) { + return new HashPartitioner(rdd.context.defaultParallelism) + } else { + return new HashPartitioner(bySize.head.partitions.size) + } + } +} + /** * A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`. * diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 9d6ea782bd83c000f5c7d666018a2f2587e3759d..584efa8adf5f3877c914cf54b0fd22eefc325963 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -16,19 +16,22 @@ import org.apache.hadoop.mapred.TextOutputFormat import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} +import spark.Partitioner._ import spark.partial.BoundedDouble import spark.partial.CountEvaluator import spark.partial.GroupedCountEvaluator import spark.partial.PartialResult +import spark.rdd.CoalescedRDD import spark.rdd.CartesianRDD import spark.rdd.FilteredRDD import spark.rdd.FlatMappedRDD import spark.rdd.GlommedRDD import spark.rdd.MappedRDD import spark.rdd.MapPartitionsRDD -import spark.rdd.MapPartitionsWithSplitRDD +import spark.rdd.MapPartitionsWithIndexRDD import spark.rdd.PipedRDD import spark.rdd.SampledRDD +import spark.rdd.SubtractedRDD import spark.rdd.UnionRDD import spark.rdd.ZippedRDD import spark.storage.StorageLevel @@ -48,7 +51,7 @@ import SparkContext._ * * Internally, each RDD is characterized by five main properties: * - * - A list of splits (partitions) + * - A list of partitions * - A function for computing each split * - A list of dependencies on other RDDs * - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned) @@ -75,13 +78,13 @@ abstract class RDD[T: ClassManifest]( // ======================================================================= /** Implemented by subclasses to compute a given partition. */ - def compute(split: Split, context: TaskContext): Iterator[T] + def compute(split: Partition, context: TaskContext): Iterator[T] /** * Implemented by subclasses to return the set of partitions in this RDD. This method will only * be called once, so it is safe to implement a time-consuming computation in it. */ - protected def getSplits: Array[Split] + protected def getPartitions: Array[Partition] /** * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only @@ -90,7 +93,7 @@ abstract class RDD[T: ClassManifest]( protected def getDependencies: Seq[Dependency[_]] = deps /** Optionally overridden by subclasses to specify placement preferences. */ - protected def getPreferredLocations(split: Split): Seq[String] = Nil + protected def getPreferredLocations(split: Partition): Seq[String] = Nil /** Optionally overridden by subclasses to specify how they are partitioned. */ val partitioner: Option[Partitioner] = None @@ -136,10 +139,10 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - // Our dependencies and splits will be gotten by calling subclass's methods below, and will + // Our dependencies and partitions will be gotten by calling subclass's methods below, and will // be overwritten when we're checkpointed private var dependencies_ : Seq[Dependency[_]] = null - @transient private var splits_ : Array[Split] = null + @transient private var partitions_ : Array[Partition] = null /** An Option holding our checkpoint RDD, if we are checkpointed */ private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD) @@ -158,15 +161,15 @@ abstract class RDD[T: ClassManifest]( } /** - * Get the array of splits of this RDD, taking into account whether the + * Get the array of partitions of this RDD, taking into account whether the * RDD is checkpointed or not. */ - final def splits: Array[Split] = { - checkpointRDD.map(_.splits).getOrElse { - if (splits_ == null) { - splits_ = getSplits + final def partitions: Array[Partition] = { + checkpointRDD.map(_.partitions).getOrElse { + if (partitions_ == null) { + partitions_ = getPartitions } - splits_ + partitions_ } } @@ -174,7 +177,7 @@ abstract class RDD[T: ClassManifest]( * Get the preferred location of a split, taking into account whether the * RDD is checkpointed or not. */ - final def preferredLocations(split: Split): Seq[String] = { + final def preferredLocations(split: Partition): Seq[String] = { checkpointRDD.map(_.getPreferredLocations(split)).getOrElse { getPreferredLocations(split) } @@ -185,7 +188,7 @@ abstract class RDD[T: ClassManifest]( * This should ''not'' be called by users directly, but is available for implementors of custom * subclasses of RDD. */ - final def iterator(split: Split, context: TaskContext): Iterator[T] = { + final def iterator(split: Partition, context: TaskContext): Iterator[T] = { if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) } else { @@ -196,7 +199,7 @@ abstract class RDD[T: ClassManifest]( /** * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing. */ - private[spark] def computeOrReadCheckpoint(split: Split, context: TaskContext): Iterator[T] = { + private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = { if (isCheckpointed) { firstParent[T].iterator(split, context) } else { @@ -226,10 +229,15 @@ abstract class RDD[T: ClassManifest]( /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(numSplits: Int): RDD[T] = - map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1) + def distinct(numPartitions: Int): RDD[T] = + map(x => (x, null)).reduceByKey((x, y) => x, numPartitions).map(_._1) - def distinct(): RDD[T] = distinct(splits.size) + def distinct(): RDD[T] = distinct(partitions.size) + + /** + * Return a new RDD that is reduced into `numPartitions` partitions. + */ + def coalesce(numPartitions: Int): RDD[T] = new CoalescedRDD(this, numPartitions) /** * Return a sampled subset of this RDD. @@ -293,19 +301,26 @@ abstract class RDD[T: ClassManifest]( */ def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other) + /** + * Return an RDD of grouped items. + */ + def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] = + groupBy[K](f, defaultPartitioner(this)) + /** * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements * mapping to that key. */ - def groupBy[K: ClassManifest](f: T => K, numSplits: Int): RDD[(K, Seq[T])] = { - val cleanF = sc.clean(f) - this.map(t => (cleanF(t), t)).groupByKey(numSplits) - } + def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] = + groupBy(f, new HashPartitioner(numPartitions)) /** * Return an RDD of grouped items. */ - def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] = groupBy[K](f, sc.defaultParallelism) + def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = { + val cleanF = sc.clean(f) + this.map(t => (cleanF(t), t)).groupByKey(p) + } /** * Return an RDD created by piping elements to a forked external process. @@ -330,14 +345,24 @@ abstract class RDD[T: ClassManifest]( preservesPartitioning: Boolean = false): RDD[U] = new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) - /** + /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. */ + def mapPartitionsWithIndex[U: ClassManifest]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = + new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) + + /** + * Return a new RDD by applying a function to each partition of this RDD, while tracking the index + * of the original partition. + */ + @deprecated("use mapPartitionsWithIndex") def mapPartitionsWithSplit[U: ClassManifest]( f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsWithSplitRDD(this, sc.clean(f), preservesPartitioning) + new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) /** * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, @@ -378,7 +403,27 @@ abstract class RDD[T: ClassManifest]( } /** - * Reduces the elements of this RDD using the specified associative binary operator. + * Return an RDD with the elements from `this` that are not in `other`. + * + * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting + * RDD will be <= us. + */ + def subtract(other: RDD[T]): RDD[T] = + subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: RDD[T], numPartitions: Int): RDD[T] = + subtract(other, new HashPartitioner(numPartitions)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: RDD[T], p: Partitioner): RDD[T] = new SubtractedRDD[T](this, other, p) + + /** + * Reduces the elements of this RDD using the specified commutative and associative binary operator. */ def reduce(f: (T, T) => T): T = { val cleanF = sc.clean(f) @@ -465,7 +510,7 @@ abstract class RDD[T: ClassManifest]( } result } - val evaluator = new CountEvaluator(splits.size, confidence) + val evaluator = new CountEvaluator(partitions.size, confidence) sc.runApproximateJob(this, countElements, evaluator, timeout) } @@ -516,7 +561,7 @@ abstract class RDD[T: ClassManifest]( } map } - val evaluator = new GroupedCountEvaluator[T](splits.size, confidence) + val evaluator = new GroupedCountEvaluator[T](partitions.size, confidence) sc.runApproximateJob(this, countPartition, evaluator, timeout) } @@ -531,7 +576,7 @@ abstract class RDD[T: ClassManifest]( } val buf = new ArrayBuffer[T] var p = 0 - while (buf.size < num && p < splits.size) { + while (buf.size < num && p < partitions.size) { val left = num - buf.size val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true) buf ++= res(0) @@ -630,27 +675,32 @@ abstract class RDD[T: ClassManifest]( /** The [[spark.SparkContext]] that this RDD was created on. */ def context = sc + // Avoid handling doCheckpoint multiple times to prevent excessive recursion + private var doCheckpointCalled = false + /** * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler * after a job using this RDD has completed (therefore the RDD has been materialized and * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. */ private[spark] def doCheckpoint() { - if (checkpointData.isDefined) { - checkpointData.get.doCheckpoint() - } else { - dependencies.foreach(_.rdd.doCheckpoint()) + if (!doCheckpointCalled) { + doCheckpointCalled = true + if (checkpointData.isDefined) { + checkpointData.get.doCheckpoint() + } else { + dependencies.foreach(_.rdd.doCheckpoint()) + } } } /** * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`) - * created from the checkpoint file, and forget its old dependencies and splits. + * created from the checkpoint file, and forget its old dependencies and partitions. */ private[spark] def markCheckpointed(checkpointRDD: RDD[_]) { clearDependencies() - dependencies_ = null - splits_ = null + partitions_ = null deps = null // Forget the constructor argument for dependencies too } @@ -665,15 +715,15 @@ abstract class RDD[T: ClassManifest]( } /** A description of this RDD and its recursive dependencies for debugging. */ - def toDebugString(): String = { + def toDebugString: String = { def debugString(rdd: RDD[_], prefix: String = ""): Seq[String] = { - Seq(prefix + rdd + " (" + rdd.splits.size + " splits)") ++ + Seq(prefix + rdd + " (" + rdd.partitions.size + " partitions)") ++ rdd.dependencies.flatMap(d => debugString(d.rdd, prefix + " ")) } debugString(this).mkString("\n") } - override def toString(): String = "%s%s[%d] at %s".format( + override def toString: String = "%s%s[%d] at %s".format( Option(name).map(_ + " ").getOrElse(""), getClass.getSimpleName, id, diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index a4a4ebaf53af83a98a072195e0f6cfa1fa28baf4..d00092e9845e2a888de7d7ca03ba23d90ae50ba9 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -16,7 +16,7 @@ private[spark] object CheckpointState extends Enumeration { /** * This class contains all the information related to RDD checkpointing. Each instance of this class * is associated with a RDD. It manages process of checkpointing of the associated RDD, as well as, - * manages the post-checkpoint state by providing the updated splits, iterator and preferred locations + * manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations * of the checkpointed RDD. */ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) @@ -67,11 +67,11 @@ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _) val newRDD = new CheckpointRDD[T](rdd.context, path) - // Change the dependencies and splits of the RDD + // Change the dependencies and partitions of the RDD RDDCheckpointData.synchronized { cpFile = Some(path) cpRDD = Some(newRDD) - rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and splits + rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed RDDCheckpointData.clearTaskCaches() logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id) @@ -79,15 +79,15 @@ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) } // Get preferred location of a split after checkpointing - def getPreferredLocations(split: Split): Seq[String] = { + def getPreferredLocations(split: Partition): Seq[String] = { RDDCheckpointData.synchronized { cpRDD.get.preferredLocations(split) } } - def getSplits: Array[Split] = { + def getPartitions: Array[Partition] = { RDDCheckpointData.synchronized { - cpRDD.get.splits + cpRDD.get.partitions } } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 0efc00d5dd338aff0832e1f40fde073858e1cd65..df23710d469f233ec48b44083dfa5d3781659e5e 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -39,7 +39,7 @@ import spark.broadcast._ import spark.deploy.LocalSparkCluster import spark.partial.ApproximateEvaluator import spark.partial.PartialResult -import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD} +import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD} import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} @@ -53,7 +53,7 @@ import storage.{StorageStatus, StorageUtils, RDDInfo} * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. * * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param jobName A name for your job, to display on the cluster web UI. + * @param appName A name for your application, to display on the cluster web UI. * @param sparkHome Location where Spark is installed on cluster nodes. * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. @@ -61,7 +61,7 @@ import storage.{StorageStatus, StorageUtils, RDDInfo} */ class SparkContext( val master: String, - val jobName: String, + val appName: String, val sparkHome: String = null, val jars: Seq[String] = Nil, environment: Map[String, String] = Map()) @@ -143,7 +143,7 @@ class SparkContext( case SPARK_REGEX(sparkUrl) => val scheduler = new ClusterScheduler(this) - val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName) + val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) scheduler.initialize(backend) scheduler @@ -162,7 +162,7 @@ class SparkContext( val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) val sparkUrl = localCluster.start() - val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName) + val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) scheduler.initialize(backend) backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { localCluster.stop() @@ -178,9 +178,9 @@ class SparkContext( val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos:// val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, jobName) + new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) } else { - new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, jobName) + new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) } scheduler.initialize(backend) scheduler @@ -216,7 +216,7 @@ class SparkContext( /** Distribute a local Scala collection to form an RDD. */ def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { - new ParallelCollection[T](this, seq, numSlices, Map[Int, Seq[String]]()) + new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } /** Distribute a local Scala collection to form an RDD. */ @@ -229,7 +229,7 @@ class SparkContext( * Create a new partition for each collection item. */ def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap - new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs) + new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) } /** @@ -439,7 +439,7 @@ class SparkContext( } /** - * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for + * Broadcast a read-only variable to the cluster, returning a [[spark.broadcast.Broadcast]] object for * reading it in distributed functions. The variable will be sent to each cluster only once. */ def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) @@ -614,14 +614,14 @@ class SparkContext( * Run a job on all partitions in an RDD and return the results in an array. */ def runJob[T, U: ClassManifest](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { - runJob(rdd, func, 0 until rdd.splits.size, false) + runJob(rdd, func, 0 until rdd.partitions.size, false) } /** * Run a job on all partitions in an RDD and return the results in an array. */ def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { - runJob(rdd, func, 0 until rdd.splits.size, false) + runJob(rdd, func, 0 until rdd.partitions.size, false) } /** @@ -632,7 +632,7 @@ class SparkContext( processPartition: (TaskContext, Iterator[T]) => U, resultHandler: (Int, U) => Unit) { - runJob[T, U](rdd, processPartition, 0 until rdd.splits.size, false, resultHandler) + runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler) } /** @@ -644,7 +644,7 @@ class SparkContext( resultHandler: (Int, U) => Unit) { val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) - runJob[T, U](rdd, processFunc, 0 until rdd.splits.size, false, resultHandler) + runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler) } /** @@ -693,10 +693,10 @@ class SparkContext( checkpointDir = Some(dir) } - /** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */ + /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ def defaultParallelism: Int = taskScheduler.defaultParallelism - /** Default min number of splits for Hadoop RDDs when not given by user */ + /** Default min number of partitions for Hadoop RDDs when not given by user */ def defaultMinSplits: Int = math.min(defaultParallelism, 2) private var nextShuffleId = new AtomicInteger(0) diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 28d643abca8f42686f7174cff3aad0d6dd338285..81daacf958b5a03d3135f4587f2d6e411b0c6a3c 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -454,4 +454,25 @@ private object Utils extends Logging { def clone[T](value: T, serializer: SerializerInstance): T = { serializer.deserialize[T](serializer.serialize(value)) } + + /** + * Detect whether this thread might be executing a shutdown hook. Will always return true if + * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. + * if System.exit was just called by a concurrent thread). + * + * Currently, this detects whether the JVM is shutting down by Runtime#addShutdownHook throwing + * an IllegalStateException. + */ + def inShutdown(): Boolean = { + try { + val hook = new Thread { + override def run() {} + } + Runtime.getRuntime.addShutdownHook(hook) + Runtime.getRuntime.removeShutdownHook(hook) + } catch { + case ise: IllegalStateException => return true + } + return false + } } diff --git a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala index 843e1bd18bdbf766819108b9abcc5eec09afa3f5..ba00b6a8448f1d28d2bd4d257aca9a62db8b7539 100644 --- a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala @@ -6,8 +6,8 @@ import spark.api.java.function.{Function => JFunction} import spark.util.StatCounter import spark.partial.{BoundedDouble, PartialResult} import spark.storage.StorageLevel - import java.lang.Double +import spark.Partitioner class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] { @@ -44,7 +44,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(numSplits: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numSplits)) + def distinct(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numPartitions)) /** * Return a new RDD containing only the elements that satisfy a predicate. @@ -52,6 +52,32 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav def filter(f: JFunction[Double, java.lang.Boolean]): JavaDoubleRDD = fromRDD(srdd.filter(x => f(x).booleanValue())) + /** + * Return a new RDD that is reduced into `numPartitions` partitions. + */ + def coalesce(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.coalesce(numPartitions)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + * + * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting + * RDD will be <= us. + */ + def subtract(other: JavaDoubleRDD): JavaDoubleRDD = + fromRDD(srdd.subtract(other)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaDoubleRDD, numPartitions: Int): JavaDoubleRDD = + fromRDD(srdd.subtract(other, numPartitions)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaDoubleRDD, p: Partitioner): JavaDoubleRDD = + fromRDD(srdd.subtract(other, p)) + /** * Return a sampled subset of this RDD. */ diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala index 8ce32e0e2fd21b57ef28f9c8c044ef2727be64e8..c1bd13c49a9e64a0dcbbad35fa2c1e9f02c32d81 100644 --- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala @@ -19,6 +19,7 @@ import spark.OrderedRDDFunctions import spark.storage.StorageLevel import spark.HashPartitioner import spark.Partitioner +import spark.Partitioner._ import spark.RDD import spark.SparkContext.rddToPairRDDFunctions @@ -54,14 +55,19 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(numSplits: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numSplits)) + def distinct(numPartitions: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numPartitions)) /** * Return a new RDD containing only the elements that satisfy a predicate. */ - def filter(f: Function[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] = + def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue())) + /** + * Return a new RDD that is reduced into `numPartitions` partitions. + */ + def coalesce(numPartitions: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.coalesce(numPartitions)) + /** * Return a sampled subset of this RDD. */ @@ -97,7 +103,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). */ - def combineByKey[C](createCombiner: Function[V, C], + def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], mergeCombiners: JFunction2[C, C, C], partitioner: Partitioner): JavaPairRDD[K, C] = { @@ -117,8 +123,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], mergeCombiners: JFunction2[C, C, C], - numSplits: Int): JavaPairRDD[K, C] = - combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numSplits)) + numPartitions: Int): JavaPairRDD[K, C] = + combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions)) /** * Merge the values for each key using an associative reduce function. This will also perform @@ -157,10 +163,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif /** * Merge the values for each key using an associative reduce function. This will also perform * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with numSplits splits. + * "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions. */ - def reduceByKey(func: JFunction2[V, V, V], numSplits: Int): JavaPairRDD[K, V] = - fromRDD(rdd.reduceByKey(func, numSplits)) + def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairRDD[K, V] = + fromRDD(rdd.reduceByKey(func, numPartitions)) /** * Group the values for each key in the RDD into a single sequence. Allows controlling the @@ -171,10 +177,31 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif /** * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with into `numSplits` partitions. + * resulting RDD with into `numPartitions` partitions. + */ + def groupByKey(numPartitions: Int): JavaPairRDD[K, JList[V]] = + fromRDD(groupByResultToJava(rdd.groupByKey(numPartitions))) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + * + * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting + * RDD will be <= us. */ - def groupByKey(numSplits: Int): JavaPairRDD[K, JList[V]] = - fromRDD(groupByResultToJava(rdd.groupByKey(numSplits))) + def subtract(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = + fromRDD(rdd.subtract(other)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaPairRDD[K, V], numPartitions: Int): JavaPairRDD[K, V] = + fromRDD(rdd.subtract(other, numPartitions)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaPairRDD[K, V], p: Partitioner): JavaPairRDD[K, V] = + fromRDD(rdd.subtract(other, p)) /** * Return a copy of the RDD partitioned using the specified partitioner. If `mapSideCombine` @@ -215,30 +242,30 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif fromRDD(rdd.rightOuterJoin(other, partitioner)) /** - * Simplified version of combineByKey that hash-partitions the resulting RDD using the default - * parallelism level. + * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing + * partitioner/parallelism level. */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = { implicit val cm: ClassManifest[C] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] - fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners)) + fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd))) } /** * Merge the values for each key using an associative reduce function. This will also perform * the merging locally on each mapper before sending results to a reducer, similarly to a - * "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level. + * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/ + * parallelism level. */ def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = { - val partitioner = rdd.defaultPartitioner(rdd) - fromRDD(reduceByKey(partitioner, func)) + fromRDD(reduceByKey(defaultPartitioner(rdd), func)) } /** * Group the values for each key in the RDD into a single sequence. Hash-partitions the - * resulting RDD with the default parallelism level. + * resulting RDD with the existing partitioner/parallelism level. */ def groupByKey(): JavaPairRDD[K, JList[V]] = fromRDD(groupByResultToJava(rdd.groupByKey())) @@ -256,14 +283,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and * (k, v2) is in `other`. Performs a hash join across the cluster. */ - def join[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, W)] = - fromRDD(rdd.join(other, numSplits)) + def join[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, W)] = + fromRDD(rdd.join(other, numPartitions)) /** * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output - * using the default level of parallelism. + * using the existing partitioner/parallelism level. */ def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Option[W])] = fromRDD(rdd.leftOuterJoin(other)) @@ -272,16 +299,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output - * into `numSplits` partitions. + * into `numPartitions` partitions. */ - def leftOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, Option[W])] = - fromRDD(rdd.leftOuterJoin(other, numSplits)) + def leftOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, Option[W])] = + fromRDD(rdd.leftOuterJoin(other, numPartitions)) /** * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting - * RDD using the default parallelism level. + * RDD using the existing partitioner/parallelism level. */ def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Option[V], W)] = fromRDD(rdd.rightOuterJoin(other)) @@ -292,8 +319,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting * RDD into the given number of partitions. */ - def rightOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (Option[V], W)] = - fromRDD(rdd.rightOuterJoin(other, numSplits)) + def rightOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (Option[V], W)] = + fromRDD(rdd.rightOuterJoin(other, numPartitions)) /** * Return the key-value pairs in this RDD to the master as a Map. @@ -304,7 +331,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif * Pass each value in the key-value pair RDD through a map function without changing the keys; * this also retains the original RDD's partitioning. */ - def mapValues[U](f: Function[V, U]): JavaPairRDD[K, U] = { + def mapValues[U](f: JFunction[V, U]): JavaPairRDD[K, U] = { implicit val cm: ClassManifest[U] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] fromRDD(rdd.mapValues(f)) @@ -357,16 +384,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. */ - def cogroup[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (JList[V], JList[W])] - = fromRDD(cogroupResultToJava(rdd.cogroup(other, numSplits))) + def cogroup[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (JList[V], JList[W])] + = fromRDD(cogroupResultToJava(rdd.cogroup(other, numPartitions))) /** * For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a * tuple with the list of values for that key in `this`, `other1` and `other2`. */ - def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numSplits: Int) + def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numPartitions: Int) : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = - fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numSplits))) + fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions))) /** Alias for cogroup. */ def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] = @@ -447,7 +474,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif */ def sortByKey(ascending: Boolean): JavaPairRDD[K, V] = { val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]] - sortByKey(comp, true) + sortByKey(comp, ascending) } /** diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala index ac31350ec3374e6a6396087b2c37681ecfe79f81..301688889898e169e52b75951c159fb1b7a3159d 100644 --- a/core/src/main/scala/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/spark/api/java/JavaRDD.scala @@ -30,7 +30,7 @@ JavaRDDLike[T, JavaRDD[T]] { /** * Return a new RDD containing the distinct elements in this RDD. */ - def distinct(numSplits: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numSplits)) + def distinct(numPartitions: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numPartitions)) /** * Return a new RDD containing only the elements that satisfy a predicate. @@ -38,6 +38,11 @@ JavaRDDLike[T, JavaRDD[T]] { def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] = wrapRDD(rdd.filter((x => f(x).booleanValue()))) + /** + * Return a new RDD that is reduced into `numPartitions` partitions. + */ + def coalesce(numPartitions: Int): JavaRDD[T] = rdd.coalesce(numPartitions) + /** * Return a sampled subset of this RDD. */ @@ -50,6 +55,26 @@ JavaRDDLike[T, JavaRDD[T]] { */ def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd)) + /** + * Return an RDD with the elements from `this` that are not in `other`. + * + * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting + * RDD will be <= us. + */ + def subtract(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.subtract(other)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaRDD[T], numPartitions: Int): JavaRDD[T] = + wrapRDD(rdd.subtract(other, numPartitions)) + + /** + * Return an RDD with the elements from `this` that are not in `other`. + */ + def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] = + wrapRDD(rdd.subtract(other, p)) + } object JavaRDD { diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index 60025b459c383168f694b636334e85663a5b1522..d884529d7a6f552227deb3989912efeff13cd5f2 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -4,7 +4,7 @@ import java.util.{List => JList} import scala.Tuple2 import scala.collection.JavaConversions._ -import spark.{SparkContext, Split, RDD, TaskContext} +import spark.{SparkContext, Partition, RDD, TaskContext} import spark.api.java.JavaPairRDD._ import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} import spark.partial.{PartialResult, BoundedDouble} @@ -12,7 +12,7 @@ import spark.storage.StorageLevel import com.google.common.base.Optional -trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] { +trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This implicit val classManifest: ClassManifest[T] @@ -20,7 +20,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround def rdd: RDD[T] /** Set of partitions in this RDD. */ - def splits: JList[Split] = new java.util.ArrayList(rdd.splits.toSeq) + def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq) /** The [[spark.SparkContext]] that this RDD was created on. */ def context: SparkContext = rdd.context @@ -36,7 +36,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround * This should ''not'' be called by users directly, but is available for implementors of custom * subclasses of RDD. */ - def iterator(split: Split, taskContext: TaskContext): java.util.Iterator[T] = + def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] = asJavaIterator(rdd.iterator(split, taskContext)) // Transformations (return a new RDD) @@ -82,12 +82,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround } /** - * Part of the workaround for SPARK-668; called in PairFlatMapWorkaround.java. + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. */ - private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = { + def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { import scala.collection.JavaConverters._ def fn = (x: T) => f.apply(x).asScala - def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] + def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType()) } @@ -110,8 +111,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround /** * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V]): - JavaPairRDD[K, V] = { + def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]): + JavaPairRDD[K2, V2] = { def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType()) } @@ -146,12 +147,12 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements * mapping to that key. */ - def groupBy[K](f: JFunction[T, K], numSplits: Int): JavaPairRDD[K, JList[T]] = { + def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = { implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] implicit val vcm: ClassManifest[JList[T]] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]] - JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numSplits)(f.returnType)))(kcm, vcm) + JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm) } /** @@ -201,7 +202,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround } /** - * Reduces the elements of this RDD using the specified associative binary operator. + * Reduces the elements of this RDD using the specified commutative and associative binary operator. */ def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f) @@ -333,6 +334,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround /** A description of this RDD and its recursive dependencies for debugging. */ def toDebugString(): String = { - rdd.toDebugString() + rdd.toDebugString } } diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala index 50b8970cd8aad362781d6bfe237a1a6e2540c0f5..f75fc27c7b2f63d024aea676bae8df9de55ded7e 100644 --- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala @@ -23,41 +23,41 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param jobName A name for your job, to display on the cluster web UI + * @param appName A name for your application, to display on the cluster web UI */ - def this(master: String, jobName: String) = this(new SparkContext(master, jobName)) + def this(master: String, appName: String) = this(new SparkContext(master, appName)) /** * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param jobName A name for your job, to display on the cluster web UI + * @param appName A name for your application, to display on the cluster web UI * @param sparkHome The SPARK_HOME directory on the slave nodes * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. */ - def this(master: String, jobName: String, sparkHome: String, jarFile: String) = - this(new SparkContext(master, jobName, sparkHome, Seq(jarFile))) + def this(master: String, appName: String, sparkHome: String, jarFile: String) = + this(new SparkContext(master, appName, sparkHome, Seq(jarFile))) /** * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param jobName A name for your job, to display on the cluster web UI + * @param appName A name for your application, to display on the cluster web UI * @param sparkHome The SPARK_HOME directory on the slave nodes * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. */ - def this(master: String, jobName: String, sparkHome: String, jars: Array[String]) = - this(new SparkContext(master, jobName, sparkHome, jars.toSeq)) + def this(master: String, appName: String, sparkHome: String, jars: Array[String]) = + this(new SparkContext(master, appName, sparkHome, jars.toSeq)) /** * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param jobName A name for your job, to display on the cluster web UI + * @param appName A name for your application, to display on the cluster web UI * @param sparkHome The SPARK_HOME directory on the slave nodes * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. * @param environment Environment variables to set on worker nodes */ - def this(master: String, jobName: String, sparkHome: String, jars: Array[String], + def this(master: String, appName: String, sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - this(new SparkContext(master, jobName, sparkHome, jars.toSeq, environment)) + this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment)) private[spark] val env = sc.env diff --git a/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java deleted file mode 100644 index 68b6fd6622742148761363045c6a06d0e8afeb74..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java +++ /dev/null @@ -1,20 +0,0 @@ -package spark.api.java; - -import spark.api.java.JavaPairRDD; -import spark.api.java.JavaRDDLike; -import spark.api.java.function.PairFlatMapFunction; - -import java.io.Serializable; - -/** - * Workaround for SPARK-668. - */ -class PairFlatMapWorkaround<T> implements Serializable { - /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. - */ - public <K, V> JavaPairRDD<K, V> flatMap(PairFlatMapFunction<T, K, V> f) { - return ((JavaRDDLike <T, ?>) this).doFlatMap(f); - } -} diff --git a/core/src/main/scala/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/spark/api/python/PythonPartitioner.scala index 519e31032304e16522319ae103ab417aa05348f4..d618c098c2bed7f74773c3ba028319fbfcd7f070 100644 --- a/core/src/main/scala/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/spark/api/python/PythonPartitioner.scala @@ -9,7 +9,7 @@ import java.util.Arrays * * Stores the unique id() of the Python-side partitioning function so that it is incorporated into * equality comparisons. Correctness requires that the id is a unique identifier for the - * lifetime of the job (i.e. that it is not re-used as the id of a different partitioning + * lifetime of the program (i.e. that it is not re-used as the id of a different partitioning * function). This can be ensured by using the Python id() function and maintaining a reference * to the Python partitioning function so that its id() is not reused. */ diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index ab8351e55e9efa614533fa2cd93947efe688b2b5..8c734773847b5d5ec12506b34dcd9e577b9cf930 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -32,11 +32,11 @@ private[spark] class PythonRDD[T: ClassManifest]( this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec, broadcastVars, accumulator) - override def getSplits = parent.splits + override def getPartitions = parent.partitions override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Split, context: TaskContext): Iterator[Array[Byte]] = { + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME") val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py")) @@ -65,7 +65,7 @@ private[spark] class PythonRDD[T: ClassManifest]( SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) val dOut = new DataOutputStream(proc.getOutputStream) - // Split index + // Partition index dOut.writeInt(split.index) // sparkFilesDir PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut) @@ -155,8 +155,8 @@ private class PythonException(msg: String) extends Exception(msg) */ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Array[Byte], Array[Byte])](prev) { - override def getSplits = prev.splits - override def compute(split: Split, context: TaskContext) = + override def getPartitions = prev.partitions + override def compute(split: Partition, context: TaskContext) = prev.iterator(split, context).grouped(2).map { case Seq(a, b) => (a, b) case x => throw new Exception("PairwiseRDD: unexpected value: " + x) diff --git a/core/src/main/scala/spark/deploy/JobDescription.scala b/core/src/main/scala/spark/deploy/ApplicationDescription.scala similarity index 66% rename from core/src/main/scala/spark/deploy/JobDescription.scala rename to core/src/main/scala/spark/deploy/ApplicationDescription.scala index 7160fc05fc822af0bb62f73d5a54341af194f4cb..6659e53b25f370f1b6747a03b7a42ec147b121de 100644 --- a/core/src/main/scala/spark/deploy/JobDescription.scala +++ b/core/src/main/scala/spark/deploy/ApplicationDescription.scala @@ -1,6 +1,6 @@ package spark.deploy -private[spark] class JobDescription( +private[spark] class ApplicationDescription( val name: String, val cores: Int, val memoryPerSlave: Int, @@ -10,5 +10,5 @@ private[spark] class JobDescription( val user = System.getProperty("user.name", "<unknown>") - override def toString: String = "JobDescription(" + name + ")" + override def toString: String = "ApplicationDescription(" + name + ")" } diff --git a/core/src/main/scala/spark/deploy/DeployMessage.scala b/core/src/main/scala/spark/deploy/DeployMessage.scala index 35f40c6e91fcfaa937fc815dd12c7401959bd11c..3cbf4fdd98be01a4424a83c18ca4aa784c022697 100644 --- a/core/src/main/scala/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/spark/deploy/DeployMessage.scala @@ -1,7 +1,7 @@ package spark.deploy import spark.deploy.ExecutorState.ExecutorState -import spark.deploy.master.{WorkerInfo, JobInfo} +import spark.deploy.master.{WorkerInfo, ApplicationInfo} import spark.deploy.worker.ExecutorRunner import scala.collection.immutable.List @@ -23,37 +23,39 @@ case class RegisterWorker( private[spark] case class ExecutorStateChanged( - jobId: String, + appId: String, execId: Int, state: ExecutorState, message: Option[String], exitStatus: Option[Int]) extends DeployMessage +private[spark] case class Heartbeat(workerId: String) extends DeployMessage + // Master to Worker private[spark] case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage private[spark] case class RegisterWorkerFailed(message: String) extends DeployMessage -private[spark] case class KillExecutor(jobId: String, execId: Int) extends DeployMessage +private[spark] case class KillExecutor(appId: String, execId: Int) extends DeployMessage private[spark] case class LaunchExecutor( - jobId: String, + appId: String, execId: Int, - jobDesc: JobDescription, + appDesc: ApplicationDescription, cores: Int, memory: Int, sparkHome: String) extends DeployMessage - // Client to Master -private[spark] case class RegisterJob(jobDescription: JobDescription) extends DeployMessage +private[spark] case class RegisterApplication(appDescription: ApplicationDescription) + extends DeployMessage // Master to Client private[spark] -case class RegisteredJob(jobId: String) extends DeployMessage +case class RegisteredApplication(appId: String) extends DeployMessage private[spark] case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) @@ -63,7 +65,7 @@ case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String exitStatus: Option[Int]) private[spark] -case class JobKilled(message: String) +case class appKilled(message: String) // Internal message in Client @@ -76,8 +78,11 @@ private[spark] case object RequestMasterState // Master to MasterWebUI private[spark] -case class MasterState(uri: String, workers: Array[WorkerInfo], activeJobs: Array[JobInfo], - completedJobs: Array[JobInfo]) +case class MasterState(host: String, port: Int, workers: Array[WorkerInfo], + activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) { + + def uri = "spark://" + host + ":" + port +} // WorkerWebUI to Worker private[spark] case object RequestWorkerState @@ -85,6 +90,6 @@ private[spark] case object RequestWorkerState // Worker to WorkerWebUI private[spark] -case class WorkerState(uri: String, workerId: String, executors: List[ExecutorRunner], +case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) diff --git a/core/src/main/scala/spark/deploy/JsonProtocol.scala b/core/src/main/scala/spark/deploy/JsonProtocol.scala index 732fa080645b0e7752f507f528d539e67b8f8ddf..38a6ebfc242c163a1f5be26c2c46dbacfe45da12 100644 --- a/core/src/main/scala/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/spark/deploy/JsonProtocol.scala @@ -1,6 +1,6 @@ package spark.deploy -import master.{JobInfo, WorkerInfo} +import master.{ApplicationInfo, WorkerInfo} import worker.ExecutorRunner import cc.spray.json._ @@ -20,8 +20,8 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { ) } - implicit object JobInfoJsonFormat extends RootJsonWriter[JobInfo] { - def write(obj: JobInfo) = JsObject( + implicit object AppInfoJsonFormat extends RootJsonWriter[ApplicationInfo] { + def write(obj: ApplicationInfo) = JsObject( "starttime" -> JsNumber(obj.startTime), "id" -> JsString(obj.id), "name" -> JsString(obj.desc.name), @@ -31,8 +31,8 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { "submitdate" -> JsString(obj.submitDate.toString)) } - implicit object JobDescriptionJsonFormat extends RootJsonWriter[JobDescription] { - def write(obj: JobDescription) = JsObject( + implicit object AppDescriptionJsonFormat extends RootJsonWriter[ApplicationDescription] { + def write(obj: ApplicationDescription) = JsObject( "name" -> JsString(obj.name), "cores" -> JsNumber(obj.cores), "memoryperslave" -> JsNumber(obj.memoryPerSlave), @@ -44,8 +44,8 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { def write(obj: ExecutorRunner) = JsObject( "id" -> JsNumber(obj.execId), "memory" -> JsNumber(obj.memory), - "jobid" -> JsString(obj.jobId), - "jobdesc" -> obj.jobDesc.toJson.asJsObject + "appid" -> JsString(obj.appId), + "appdesc" -> obj.appDesc.toJson.asJsObject ) } @@ -57,8 +57,8 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol { "coresused" -> JsNumber(obj.workers.map(_.coresUsed).sum), "memory" -> JsNumber(obj.workers.map(_.memory).sum), "memoryused" -> JsNumber(obj.workers.map(_.memoryUsed).sum), - "activejobs" -> JsArray(obj.activeJobs.toList.map(_.toJson)), - "completedjobs" -> JsArray(obj.completedJobs.toList.map(_.toJson)) + "activeapps" -> JsArray(obj.activeApps.toList.map(_.toJson)), + "completedapps" -> JsArray(obj.completedApps.toList.map(_.toJson)) ) } diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala index a63eee12339d12f8ac3b3b236b54d6cec603e59a..1a95524cf9c496734a88f41fd08c35fcd10d874f 100644 --- a/core/src/main/scala/spark/deploy/client/Client.scala +++ b/core/src/main/scala/spark/deploy/client/Client.scala @@ -8,25 +8,25 @@ import akka.pattern.AskTimeoutException import spark.{SparkException, Logging} import akka.remote.RemoteClientLifeCycleEvent import akka.remote.RemoteClientShutdown -import spark.deploy.RegisterJob +import spark.deploy.RegisterApplication import spark.deploy.master.Master import akka.remote.RemoteClientDisconnected import akka.actor.Terminated import akka.dispatch.Await /** - * The main class used to talk to a Spark deploy cluster. Takes a master URL, a job description, - * and a listener for job events, and calls back the listener when various events occur. + * The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description, + * and a listener for cluster events, and calls back the listener when various events occur. */ private[spark] class Client( actorSystem: ActorSystem, masterUrl: String, - jobDescription: JobDescription, + appDescription: ApplicationDescription, listener: ClientListener) extends Logging { var actor: ActorRef = null - var jobId: String = null + var appId: String = null class ClientActor extends Actor with Logging { var master: ActorRef = null @@ -38,7 +38,7 @@ private[spark] class Client( try { master = context.actorFor(Master.toAkkaUrl(masterUrl)) masterAddress = master.path.address - master ! RegisterJob(jobDescription) + master ! RegisterApplication(appDescription) context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) context.watch(master) // Doesn't work with remote actors, but useful for testing } catch { @@ -50,17 +50,17 @@ private[spark] class Client( } override def receive = { - case RegisteredJob(jobId_) => - jobId = jobId_ - listener.connected(jobId) + case RegisteredApplication(appId_) => + appId = appId_ + listener.connected(appId) case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) => - val fullId = jobId + "/" + id + val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores)) listener.executorAdded(fullId, workerId, host, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => - val fullId = jobId + "/" + id + val fullId = appId + "/" + id val messageText = message.map(s => " (" + s + ")").getOrElse("") logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText)) if (ExecutorState.isFinished(state)) { @@ -107,7 +107,7 @@ private[spark] class Client( def stop() { if (actor != null) { try { - val timeout = 1.seconds + val timeout = 5.seconds val future = actor.ask(StopClient)(timeout) Await.result(future, timeout) } catch { diff --git a/core/src/main/scala/spark/deploy/client/ClientListener.scala b/core/src/main/scala/spark/deploy/client/ClientListener.scala index 7035f4b3942429e2bbea516fcd544d31ba682d36..b7008321df564976d37ca9f428d7d47920a30bf3 100644 --- a/core/src/main/scala/spark/deploy/client/ClientListener.scala +++ b/core/src/main/scala/spark/deploy/client/ClientListener.scala @@ -8,7 +8,7 @@ package spark.deploy.client * Users of this API should *not* block inside the callback methods. */ private[spark] trait ClientListener { - def connected(jobId: String): Unit + def connected(appId: String): Unit def disconnected(): Unit diff --git a/core/src/main/scala/spark/deploy/client/TestClient.scala b/core/src/main/scala/spark/deploy/client/TestClient.scala index 8764c400e26d8d0583a76cce9a74daba501457f3..dc004b59ca5ac247d4e7c8125b775a1f1698e7ea 100644 --- a/core/src/main/scala/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/spark/deploy/client/TestClient.scala @@ -2,13 +2,13 @@ package spark.deploy.client import spark.util.AkkaUtils import spark.{Logging, Utils} -import spark.deploy.{Command, JobDescription} +import spark.deploy.{Command, ApplicationDescription} private[spark] object TestClient { class TestListener extends ClientListener with Logging { def connected(id: String) { - logInfo("Connected to master, got job ID " + id) + logInfo("Connected to master, got app ID " + id) } def disconnected() { @@ -24,7 +24,7 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0) - val desc = new JobDescription( + val desc = new ApplicationDescription( "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home") val listener = new TestListener val client = new Client(actorSystem, url, desc, listener) diff --git a/core/src/main/scala/spark/deploy/master/JobInfo.scala b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala similarity index 84% rename from core/src/main/scala/spark/deploy/master/JobInfo.scala rename to core/src/main/scala/spark/deploy/master/ApplicationInfo.scala index a274b21c346f2da8605cf1d5b39f9e1efaf059c3..3591a9407237a765003af9d59d99603c70cf06a8 100644 --- a/core/src/main/scala/spark/deploy/master/JobInfo.scala +++ b/core/src/main/scala/spark/deploy/master/ApplicationInfo.scala @@ -1,18 +1,18 @@ package spark.deploy.master -import spark.deploy.JobDescription +import spark.deploy.ApplicationDescription import java.util.Date import akka.actor.ActorRef import scala.collection.mutable -private[spark] class JobInfo( +private[spark] class ApplicationInfo( val startTime: Long, val id: String, - val desc: JobDescription, + val desc: ApplicationDescription, val submitDate: Date, val driver: ActorRef) { - var state = JobState.WAITING + var state = ApplicationState.WAITING var executors = new mutable.HashMap[Int, ExecutorInfo] var coresGranted = 0 var endTime = -1L @@ -48,7 +48,7 @@ private[spark] class JobInfo( _retryCount } - def markFinished(endState: JobState.Value) { + def markFinished(endState: ApplicationState.Value) { state = endState endTime = System.currentTimeMillis() } diff --git a/core/src/main/scala/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/spark/deploy/master/ApplicationState.scala new file mode 100644 index 0000000000000000000000000000000000000000..15016b388d2d82aaf12ce936122ede712c38bade --- /dev/null +++ b/core/src/main/scala/spark/deploy/master/ApplicationState.scala @@ -0,0 +1,11 @@ +package spark.deploy.master + +private[spark] object ApplicationState + extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") { + + type ApplicationState = Value + + val WAITING, RUNNING, FINISHED, FAILED = Value + + val MAX_NUM_RETRY = 10 +} diff --git a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala index 1db2c326333f1751f22ff83d215b277020747718..48e6055fb572aefd1b71b5bf40fd838d86fe4b0a 100644 --- a/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala +++ b/core/src/main/scala/spark/deploy/master/ExecutorInfo.scala @@ -4,12 +4,12 @@ import spark.deploy.ExecutorState private[spark] class ExecutorInfo( val id: Int, - val job: JobInfo, + val application: ApplicationInfo, val worker: WorkerInfo, val cores: Int, val memory: Int) { var state = ExecutorState.LAUNCHING - def fullId: String = job.id + "/" + id + def fullId: String = application.id + "/" + id } diff --git a/core/src/main/scala/spark/deploy/master/JobState.scala b/core/src/main/scala/spark/deploy/master/JobState.scala deleted file mode 100644 index 2b70cf01918d7b971fccd4c6859536ff8c323e40..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/deploy/master/JobState.scala +++ /dev/null @@ -1,9 +0,0 @@ -package spark.deploy.master - -private[spark] object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") { - type JobState = Value - - val WAITING, RUNNING, FINISHED, FAILED = Value - - val MAX_NUM_RETRY = 10 -} diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 92e7914b1b931a01e4237d497821306ea2989463..b7f167425f5a70b83a4924f1f9bf804487893fe1 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -3,6 +3,7 @@ package spark.deploy.master import akka.actor._ import akka.actor.Terminated import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown} +import akka.util.duration._ import java.text.SimpleDateFormat import java.util.Date @@ -15,21 +16,24 @@ import spark.util.AkkaUtils private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging { - val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For job IDs + val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000 - var nextJobNumber = 0 + var nextAppNumber = 0 val workers = new HashSet[WorkerInfo] val idToWorker = new HashMap[String, WorkerInfo] val actorToWorker = new HashMap[ActorRef, WorkerInfo] val addressToWorker = new HashMap[Address, WorkerInfo] - val jobs = new HashSet[JobInfo] - val idToJob = new HashMap[String, JobInfo] - val actorToJob = new HashMap[ActorRef, JobInfo] - val addressToJob = new HashMap[Address, JobInfo] + val apps = new HashSet[ApplicationInfo] + val idToApp = new HashMap[String, ApplicationInfo] + val actorToApp = new HashMap[ActorRef, ApplicationInfo] + val addressToApp = new HashMap[Address, ApplicationInfo] - val waitingJobs = new ArrayBuffer[JobInfo] - val completedJobs = new ArrayBuffer[JobInfo] + val waitingApps = new ArrayBuffer[ApplicationInfo] + val completedApps = new ArrayBuffer[ApplicationInfo] + + var firstApp: Option[ApplicationInfo] = None val masterPublicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") @@ -37,15 +41,16 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } // As a temporary workaround before better ways of configuring memory, we allow users to set - // a flag that will perform round-robin scheduling across the nodes (spreading out each job - // among all the nodes) instead of trying to consolidate each job onto a small # of nodes. - val spreadOutJobs = System.getProperty("spark.deploy.spreadOut", "false").toBoolean + // a flag that will perform round-robin scheduling across the nodes (spreading out each app + // among all the nodes) instead of trying to consolidate each app onto a small # of nodes. + val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "false").toBoolean override def preStart() { logInfo("Starting Spark master at spark://" + ip + ":" + port) // Listen for remote client disconnection events, since they don't go through Akka's watch() context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) startWebUi() + context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis)(timeOutDeadWorkers()) } def startWebUi() { @@ -73,92 +78,101 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor } } - case RegisterJob(description) => { - logInfo("Registering job " + description.name) - val job = addJob(description, sender) - logInfo("Registered job " + description.name + " with ID " + job.id) - waitingJobs += job + case RegisterApplication(description) => { + logInfo("Registering app " + description.name) + val app = addApplication(description, sender) + logInfo("Registered app " + description.name + " with ID " + app.id) + waitingApps += app context.watch(sender) // This doesn't work with remote actors but helps for testing - sender ! RegisteredJob(job.id) + sender ! RegisteredApplication(app.id) schedule() } - case ExecutorStateChanged(jobId, execId, state, message, exitStatus) => { - val execOption = idToJob.get(jobId).flatMap(job => job.executors.get(execId)) + case ExecutorStateChanged(appId, execId, state, message, exitStatus) => { + val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId)) execOption match { case Some(exec) => { exec.state = state - exec.job.driver ! ExecutorUpdated(execId, state, message, exitStatus) + exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus) if (ExecutorState.isFinished(state)) { - val jobInfo = idToJob(jobId) - // Remove this executor from the worker and job + val appInfo = idToApp(appId) + // Remove this executor from the worker and app logInfo("Removing executor " + exec.fullId + " because it is " + state) - jobInfo.removeExecutor(exec) + appInfo.removeExecutor(exec) exec.worker.removeExecutor(exec) // Only retry certain number of times so we don't go into an infinite loop. - if (jobInfo.incrementRetryCount < JobState.MAX_NUM_RETRY) { + if (appInfo.incrementRetryCount < ApplicationState.MAX_NUM_RETRY) { schedule() } else { - logError("Job %s with ID %s failed %d times, removing it".format( - jobInfo.desc.name, jobInfo.id, jobInfo.retryCount)) - removeJob(jobInfo) + logError("Application %s with ID %s failed %d times, removing it".format( + appInfo.desc.name, appInfo.id, appInfo.retryCount)) + removeApplication(appInfo) } } } case None => - logWarning("Got status update for unknown executor " + jobId + "/" + execId) + logWarning("Got status update for unknown executor " + appId + "/" + execId) + } + } + + case Heartbeat(workerId) => { + idToWorker.get(workerId) match { + case Some(workerInfo) => + workerInfo.lastHeartbeat = System.currentTimeMillis() + case None => + logWarning("Got heartbeat from unregistered worker " + workerId) } } case Terminated(actor) => { - // The disconnected actor could've been either a worker or a job; remove whichever of + // The disconnected actor could've been either a worker or an app; remove whichever of // those we have an entry for in the corresponding actor hashmap actorToWorker.get(actor).foreach(removeWorker) - actorToJob.get(actor).foreach(removeJob) + actorToApp.get(actor).foreach(removeApplication) } case RemoteClientDisconnected(transport, address) => { - // The disconnected client could've been either a worker or a job; remove whichever it was + // The disconnected client could've been either a worker or an app; remove whichever it was addressToWorker.get(address).foreach(removeWorker) - addressToJob.get(address).foreach(removeJob) + addressToApp.get(address).foreach(removeApplication) } case RemoteClientShutdown(transport, address) => { - // The disconnected client could've been either a worker or a job; remove whichever it was + // The disconnected client could've been either a worker or an app; remove whichever it was addressToWorker.get(address).foreach(removeWorker) - addressToJob.get(address).foreach(removeJob) + addressToApp.get(address).foreach(removeApplication) } case RequestMasterState => { - sender ! MasterState(ip + ":" + port, workers.toArray, jobs.toArray, completedJobs.toArray) + sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray) } } /** - * Can a job use the given worker? True if the worker has enough memory and we haven't already - * launched an executor for the job on it (right now the standalone backend doesn't like having + * Can an app use the given worker? True if the worker has enough memory and we haven't already + * launched an executor for the app on it (right now the standalone backend doesn't like having * two executors on the same worker). */ - def canUse(job: JobInfo, worker: WorkerInfo): Boolean = { - worker.memoryFree >= job.desc.memoryPerSlave && !worker.hasExecutor(job) + def canUse(app: ApplicationInfo, worker: WorkerInfo): Boolean = { + worker.memoryFree >= app.desc.memoryPerSlave && !worker.hasExecutor(app) } /** - * Schedule the currently available resources among waiting jobs. This method will be called - * every time a new job joins or resource availability changes. + * Schedule the currently available resources among waiting apps. This method will be called + * every time a new app joins or resource availability changes. */ def schedule() { - // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first job - // in the queue, then the second job, etc. - if (spreadOutJobs) { - // Try to spread out each job among all the nodes, until it has all its cores - for (job <- waitingJobs if job.coresLeft > 0) { + // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app + // in the queue, then the second app, etc. + if (spreadOutApps) { + // Try to spread out each app among all the nodes, until it has all its cores + for (app <- waitingApps if app.coresLeft > 0) { val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(canUse(job, _)).sortBy(_.coresFree).reverse + .filter(canUse(app, _)).sortBy(_.coresFree).reverse val numUsable = usableWorkers.length val assigned = new Array[Int](numUsable) // Number of cores to give on each node - var toAssign = math.min(job.coresLeft, usableWorkers.map(_.coresFree).sum) + var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) var pos = 0 while (toAssign > 0) { if (usableWorkers(pos).coresFree - assigned(pos) > 0) { @@ -170,22 +184,22 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor // Now that we've decided how many cores to give on each node, let's actually give them for (pos <- 0 until numUsable) { if (assigned(pos) > 0) { - val exec = job.addExecutor(usableWorkers(pos), assigned(pos)) - launchExecutor(usableWorkers(pos), exec, job.desc.sparkHome) - job.state = JobState.RUNNING + val exec = app.addExecutor(usableWorkers(pos), assigned(pos)) + launchExecutor(usableWorkers(pos), exec, app.desc.sparkHome) + app.state = ApplicationState.RUNNING } } } } else { - // Pack each job into as few nodes as possible until we've assigned all its cores - for (worker <- workers if worker.coresFree > 0) { - for (job <- waitingJobs if job.coresLeft > 0) { - if (canUse(job, worker)) { - val coresToUse = math.min(worker.coresFree, job.coresLeft) + // Pack each app into as few nodes as possible until we've assigned all its cores + for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) { + for (app <- waitingApps if app.coresLeft > 0) { + if (canUse(app, worker)) { + val coresToUse = math.min(worker.coresFree, app.coresLeft) if (coresToUse > 0) { - val exec = job.addExecutor(worker, coresToUse) - launchExecutor(worker, exec, job.desc.sparkHome) - job.state = JobState.RUNNING + val exec = app.addExecutor(worker, coresToUse) + launchExecutor(worker, exec, app.desc.sparkHome) + app.state = ApplicationState.RUNNING } } } @@ -196,8 +210,8 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(exec.job.id, exec.id, exec.job.desc, exec.cores, exec.memory, sparkHome) - exec.job.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) + worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome) + exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory) } def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, @@ -219,45 +233,65 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor actorToWorker -= worker.actor addressToWorker -= worker.actor.path.address for (exec <- worker.executors.values) { - exec.job.driver ! ExecutorStateChanged(exec.job.id, exec.id, ExecutorState.LOST, None, None) - exec.job.executors -= exec.id + logInfo("Telling app of lost executor: " + exec.id) + exec.application.driver ! ExecutorUpdated(exec.id, ExecutorState.LOST, Some("worker lost"), None) + exec.application.removeExecutor(exec) } } - def addJob(desc: JobDescription, driver: ActorRef): JobInfo = { + def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) - val job = new JobInfo(now, newJobId(date), desc, date, driver) - jobs += job - idToJob(job.id) = job - actorToJob(driver) = job - addressToJob(driver.path.address) = job - return job + val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver) + apps += app + idToApp(app.id) = app + actorToApp(driver) = app + addressToApp(driver.path.address) = app + if (firstApp == None) { + firstApp = Some(app) + } + val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray + if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= desc.memoryPerSlave)) { + logWarning("Could not find any workers with enough memory for " + firstApp.get.id) + } + return app } - def removeJob(job: JobInfo) { - if (jobs.contains(job)) { - logInfo("Removing job " + job.id) - jobs -= job - idToJob -= job.id - actorToJob -= job.driver - addressToWorker -= job.driver.path.address - completedJobs += job // Remember it in our history - waitingJobs -= job - for (exec <- job.executors.values) { + def removeApplication(app: ApplicationInfo) { + if (apps.contains(app)) { + logInfo("Removing app " + app.id) + apps -= app + idToApp -= app.id + actorToApp -= app.driver + addressToWorker -= app.driver.path.address + completedApps += app // Remember it in our history + waitingApps -= app + for (exec <- app.executors.values) { exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(exec.job.id, exec.id) + exec.worker.actor ! KillExecutor(exec.application.id, exec.id) } - job.markFinished(JobState.FINISHED) // TODO: Mark it as FAILED if it failed + app.markFinished(ApplicationState.FINISHED) // TODO: Mark it as FAILED if it failed schedule() } } - /** Generate a new job ID given a job's submission date */ - def newJobId(submitDate: Date): String = { - val jobId = "job-%s-%04d".format(DATE_FORMAT.format(submitDate), nextJobNumber) - nextJobNumber += 1 - jobId + /** Generate a new app ID given a app's submission date */ + def newApplicationId(submitDate: Date): String = { + val appId = "app-%s-%04d".format(DATE_FORMAT.format(submitDate), nextAppNumber) + nextAppNumber += 1 + appId + } + + /** Check for, and remove, any timed-out workers */ + def timeOutDeadWorkers() { + // Copy the workers into an array so we don't modify the hashset while iterating through it + val expirationTime = System.currentTimeMillis() - WORKER_TIMEOUT + val toRemove = workers.filter(_.lastHeartbeat < expirationTime).toArray + for (worker <- toRemove) { + logWarning("Removing %s because we got no heartbeat in %d seconds".format( + worker.id, WORKER_TIMEOUT)) + removeWorker(worker) + } } } diff --git a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala index 529f72e9da1c2c618f23306d485d88a0fbd986d8..54faa375fbd468e4ea05ce3fda9a9e41dfcea166 100644 --- a/core/src/main/scala/spark/deploy/master/MasterWebUI.scala +++ b/core/src/main/scala/spark/deploy/master/MasterWebUI.scala @@ -40,27 +40,27 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct } } } ~ - path("job") { - parameters("jobId", 'format ?) { - case (jobId, Some(js)) if (js.equalsIgnoreCase("json")) => + path("app") { + parameters("appId", 'format ?) { + case (appId, Some(js)) if (js.equalsIgnoreCase("json")) => val future = master ? RequestMasterState - val jobInfo = for (masterState <- future.mapTo[MasterState]) yield { - masterState.activeJobs.find(_.id == jobId).getOrElse({ - masterState.completedJobs.find(_.id == jobId).getOrElse(null) + val appInfo = for (masterState <- future.mapTo[MasterState]) yield { + masterState.activeApps.find(_.id == appId).getOrElse({ + masterState.completedApps.find(_.id == appId).getOrElse(null) }) } respondWithMediaType(MediaTypes.`application/json`) { ctx => - ctx.complete(jobInfo.mapTo[JobInfo]) + ctx.complete(appInfo.mapTo[ApplicationInfo]) } - case (jobId, _) => + case (appId, _) => completeWith { val future = master ? RequestMasterState future.map { state => val masterState = state.asInstanceOf[MasterState] - val job = masterState.activeJobs.find(_.id == jobId).getOrElse({ - masterState.completedJobs.find(_.id == jobId).getOrElse(null) + val app = masterState.activeApps.find(_.id == appId).getOrElse({ + masterState.completedApps.find(_.id == appId).getOrElse(null) }) - spark.deploy.master.html.job_details.render(job) + spark.deploy.master.html.app_details.render(app) } } } diff --git a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala index 5a7f5fef8a546812d20719af8d2a0f3dcab1af29..23df1bb463288721e38379023b3fd01c4c8632d8 100644 --- a/core/src/main/scala/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/spark/deploy/master/WorkerInfo.scala @@ -18,6 +18,8 @@ private[spark] class WorkerInfo( var coresUsed = 0 var memoryUsed = 0 + var lastHeartbeat = System.currentTimeMillis() + def coresFree: Int = cores - coresUsed def memoryFree: Int = memory - memoryUsed @@ -35,8 +37,8 @@ private[spark] class WorkerInfo( } } - def hasExecutor(job: JobInfo): Boolean = { - executors.values.exists(_.job == job) + def hasExecutor(app: ApplicationInfo): Boolean = { + executors.values.exists(_.application == app) } def webUiAddress : String = { diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala index 4ef637090c4ebbe5f63d9ce599133009c7c06de8..de11771c8e62d7cdb0b629a105bb8d1ca21e544c 100644 --- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala @@ -1,7 +1,7 @@ package spark.deploy.worker import java.io._ -import spark.deploy.{ExecutorState, ExecutorStateChanged, JobDescription} +import spark.deploy.{ExecutorState, ExecutorStateChanged, ApplicationDescription} import akka.actor.ActorRef import spark.{Utils, Logging} import java.net.{URI, URL} @@ -14,9 +14,9 @@ import spark.deploy.ExecutorStateChanged * Manages the execution of one executor process. */ private[spark] class ExecutorRunner( - val jobId: String, + val appId: String, val execId: Int, - val jobDesc: JobDescription, + val appDesc: ApplicationDescription, val cores: Int, val memory: Int, val worker: ActorRef, @@ -26,7 +26,7 @@ private[spark] class ExecutorRunner( val workDir: File) extends Logging { - val fullId = jobId + "/" + execId + val fullId = appId + "/" + execId var workerThread: Thread = null var process: Process = null var shutdownHook: Thread = null @@ -60,7 +60,7 @@ private[spark] class ExecutorRunner( process.destroy() process.waitFor() } - worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None, None) + worker ! ExecutorStateChanged(appId, execId, ExecutorState.KILLED, None, None) Runtime.getRuntime.removeShutdownHook(shutdownHook) } } @@ -74,10 +74,10 @@ private[spark] class ExecutorRunner( } def buildCommandSeq(): Seq[String] = { - val command = jobDesc.command - val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run"; + val command = appDesc.command + val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run" val runScript = new File(sparkHome, script).getCanonicalPath - Seq(runScript, command.mainClass) ++ command.arguments.map(substituteVariables) + Seq(runScript, command.mainClass) ++ (command.arguments ++ Seq(appId)).map(substituteVariables) } /** Spawn a thread that will redirect a given stream to a file */ @@ -96,12 +96,12 @@ private[spark] class ExecutorRunner( } /** - * Download and run the executor described in our JobDescription + * Download and run the executor described in our ApplicationDescription */ def fetchAndRunExecutor() { try { // Create the executor's working directory - val executorDir = new File(workDir, jobId + "/" + execId) + val executorDir = new File(workDir, appId + "/" + execId) if (!executorDir.mkdirs()) { throw new IOException("Failed to create directory " + executorDir) } @@ -110,7 +110,7 @@ private[spark] class ExecutorRunner( val command = buildCommandSeq() val builder = new ProcessBuilder(command: _*).directory(executorDir) val env = builder.environment() - for ((key, value) <- jobDesc.command.environment) { + for ((key, value) <- appDesc.command.environment) { env.put(key, value) } env.put("SPARK_MEM", memory.toString + "m") @@ -128,7 +128,7 @@ private[spark] class ExecutorRunner( // times on the same machine. val exitCode = process.waitFor() val message = "Command exited with code " + exitCode - worker ! ExecutorStateChanged(jobId, execId, ExecutorState.FAILED, Some(message), + worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), Some(exitCode)) } catch { case interrupted: InterruptedException => @@ -140,7 +140,7 @@ private[spark] class ExecutorRunner( process.destroy() } val message = e.getClass + ": " + e.getMessage - worker ! ExecutorStateChanged(jobId, execId, ExecutorState.FAILED, Some(message), None) + worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), None) } } } diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index 38547ec4f1dffa6ee883fa1cc947751e22f0b6b4..2bbc931316291fae8c911b3c24ca3105a0f799ec 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -2,6 +2,7 @@ package spark.deploy.worker import scala.collection.mutable.{ArrayBuffer, HashMap} import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated} +import akka.util.duration._ import spark.{Logging, Utils} import spark.util.AkkaUtils import spark.deploy._ @@ -26,6 +27,9 @@ private[spark] class Worker( val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs + // Send a heartbeat every (heartbeat timeout) / 4 milliseconds + val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4 + var master: ActorRef = null var masterWebUiUrl : String = "" val workerId = generateWorkerId() @@ -97,24 +101,27 @@ private[spark] class Worker( case RegisteredWorker(url) => masterWebUiUrl = url logInfo("Successfully registered with master") + context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) { + master ! Heartbeat(workerId) + } case RegisterWorkerFailed(message) => logError("Worker registration failed: " + message) System.exit(1) - case LaunchExecutor(jobId, execId, jobDesc, cores_, memory_, execSparkHome_) => - logInfo("Asked to launch executor %s/%d for %s".format(jobId, execId, jobDesc.name)) + case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) => + logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) val manager = new ExecutorRunner( - jobId, execId, jobDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir) - executors(jobId + "/" + execId) = manager + appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir) + executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ memoryUsed += memory_ - master ! ExecutorStateChanged(jobId, execId, ExecutorState.RUNNING, None, None) + master ! ExecutorStateChanged(appId, execId, ExecutorState.RUNNING, None, None) - case ExecutorStateChanged(jobId, execId, state, message, exitStatus) => - master ! ExecutorStateChanged(jobId, execId, state, message, exitStatus) - val fullId = jobId + "/" + execId + case ExecutorStateChanged(appId, execId, state, message, exitStatus) => + master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) + val fullId = appId + "/" + execId if (ExecutorState.isFinished(state)) { val executor = executors(fullId) logInfo("Executor " + fullId + " finished with state " + state + @@ -126,8 +133,8 @@ private[spark] class Worker( memoryUsed -= executor.memory } - case KillExecutor(jobId, execId) => - val fullId = jobId + "/" + execId + case KillExecutor(appId, execId) => + val fullId = appId + "/" + execId executors.get(fullId) match { case Some(executor) => logInfo("Asked to kill executor " + fullId) @@ -140,7 +147,7 @@ private[spark] class Worker( masterDisconnected() case RequestWorkerState => { - sender ! WorkerState(ip + ":" + port, workerId, executors.values.toList, + sender ! WorkerState(ip, port, workerId, executors.values.toList, finishedExecutors.values.toList, masterUrl, cores, memory, coresUsed, memoryUsed, masterWebUiUrl) } diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala index 37524a7c82f8b347d7fbe5f571a9647ff366bb68..08f02bad80d7f47b2f019745216538eda2640223 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala @@ -92,7 +92,7 @@ private[spark] class WorkerArguments(args: Array[String]) { "Options:\n" + " -c CORES, --cores CORES Number of cores to use\n" + " -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" + - " -d DIR, --work-dir DIR Directory to run jobs in (default: SPARK_HOME/work)\n" + + " -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" + " -i IP, --ip IP IP address or DNS name to listen on\n" + " -p PORT, --port PORT Port to listen on (default: random)\n" + " --webui-port PORT Port for web UI (default: 8081)") diff --git a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala index ef81f072a308272ae1b1dbb6f70eb0ba794df5b4..135cc2e86cc9267661775a915f57f2fe470dec00 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala @@ -41,9 +41,9 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct } } ~ path("log") { - parameters("jobId", "executorId", "logType") { (jobId, executorId, logType) => + parameters("appId", "executorId", "logType") { (appId, executorId, logType) => respondWithMediaType(cc.spray.http.MediaTypes.`text/plain`) { - getFromFileName("work/" + jobId + "/" + executorId + "/" + logType) + getFromFileName("work/" + appId + "/" + executorId + "/" + logType) } } } ~ diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala index bd21ba719a77cf2eaca6b2f70002f191daca4653..5de09030aa1b3b4318c1eb6051eed3e1a3fb23ef 100644 --- a/core/src/main/scala/spark/executor/Executor.scala +++ b/core/src/main/scala/spark/executor/Executor.scala @@ -50,14 +50,19 @@ private[spark] class Executor extends Logging { override def uncaughtException(thread: Thread, exception: Throwable) { try { logError("Uncaught exception in thread " + thread, exception) - if (exception.isInstanceOf[OutOfMemoryError]) { - System.exit(ExecutorExitCode.OOM) - } else { - System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION) + + // We may have been called from a shutdown hook. If so, we must not call System.exit(). + // (If we do, we will deadlock.) + if (!Utils.inShutdown()) { + if (exception.isInstanceOf[OutOfMemoryError]) { + System.exit(ExecutorExitCode.OOM) + } else { + System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION) + } } } catch { - case oom: OutOfMemoryError => System.exit(ExecutorExitCode.OOM) - case t: Throwable => System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE) + case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM) + case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE) } } } diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala index 224c126fdd1eec87d34d98c33682c79ace097438..9a82c3054c0fcf3ff6d0ecf5dc36c54d658fdaa0 100644 --- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala @@ -68,8 +68,9 @@ private[spark] object StandaloneExecutorBackend { } def main(args: Array[String]) { - if (args.length != 4) { - System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores>") + if (args.length < 4) { + //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors + System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]") System.exit(1) } run(args(0), args(1), args(2), args(3).toInt) diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala index cd5b7d57f32f58c3c1a5e20e1dcc011d0650a7e1..d1451bc2124c581eff01dfae5277612ea5c995c7 100644 --- a/core/src/main/scala/spark/network/Connection.scala +++ b/core/src/main/scala/spark/network/Connection.scala @@ -198,7 +198,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { outbox.synchronized { outbox.addMessage(message) if (channel.isConnected) { - changeConnectionKeyInterest(SelectionKey.OP_WRITE) + changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ) } } } @@ -219,7 +219,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { def finishConnect() { try { channel.finishConnect - changeConnectionKeyInterest(SelectionKey.OP_WRITE) + changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ) logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") } catch { case e: Exception => { @@ -239,8 +239,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { currentBuffers ++= chunk.buffers } case None => { - changeConnectionKeyInterest(0) - /*key.interestOps(0)*/ + changeConnectionKeyInterest(SelectionKey.OP_READ) return } } @@ -267,6 +266,23 @@ extends Connection(SocketChannel.open, selector_, remoteId_) { } } } + + override def read() { + // We don't expect the other side to send anything; so, we just read to detect an error or EOF. + try { + val length = channel.read(ByteBuffer.allocate(1)) + if (length == -1) { // EOF + close() + } else if (length > 0) { + logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId) + } + } catch { + case e: Exception => + logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e) + callOnExceptionCallback(e) + close() + } + } } diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index c7f226044d1e9a9a11db7487a64c31c51402af46..b6ec664d7e81bb581f1be553b3b1dfaafca61865 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -66,31 +66,28 @@ private[spark] class ConnectionManager(port: Int) extends Logging { val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) - val thisInstance = this val selectorThread = new Thread("connection-manager-thread") { - override def run() { - thisInstance.run() - } + override def run() = ConnectionManager.this.run() } selectorThread.setDaemon(true) selectorThread.start() - def run() { + private def run() { try { while(!selectorThread.isInterrupted) { - for( (connectionManagerId, sendingConnection) <- connectionRequests) { + for ((connectionManagerId, sendingConnection) <- connectionRequests) { sendingConnection.connect() addConnection(sendingConnection) connectionRequests -= connectionManagerId } sendMessageRequests.synchronized { - while(!sendMessageRequests.isEmpty) { + while (!sendMessageRequests.isEmpty) { val (message, connection) = sendMessageRequests.dequeue connection.send(message) } } - while(!keyInterestChangeRequests.isEmpty) { + while (!keyInterestChangeRequests.isEmpty) { val (key, ops) = keyInterestChangeRequests.dequeue val connection = connectionsByKey(key) val lastOps = key.interestOps() @@ -126,14 +123,11 @@ private[spark] class ConnectionManager(port: Int) extends Logging { if (key.isValid) { if (key.isAcceptable) { acceptConnection(key) - } else - if (key.isConnectable) { + } else if (key.isConnectable) { connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect() - } else - if (key.isReadable) { + } else if (key.isReadable) { connectionsByKey(key).read() - } else - if (key.isWritable) { + } else if (key.isWritable) { connectionsByKey(key).write() } } @@ -144,7 +138,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } } - def acceptConnection(key: SelectionKey) { + private def acceptConnection(key: SelectionKey) { val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] val newChannel = serverChannel.accept() val newConnection = new ReceivingConnection(newChannel, selector) @@ -154,7 +148,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") } - def addConnection(connection: Connection) { + private def addConnection(connection: Connection) { connectionsByKey += ((connection.key, connection)) if (connection.isInstanceOf[SendingConnection]) { val sendingConnection = connection.asInstanceOf[SendingConnection] @@ -165,7 +159,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging { connection.onClose(removeConnection) } - def removeConnection(connection: Connection) { + private def removeConnection(connection: Connection) { connectionsByKey -= connection.key if (connection.isInstanceOf[SendingConnection]) { val sendingConnection = connection.asInstanceOf[SendingConnection] @@ -222,16 +216,16 @@ private[spark] class ConnectionManager(port: Int) extends Logging { } } - def handleConnectionError(connection: Connection, e: Exception) { + private def handleConnectionError(connection: Connection, e: Exception) { logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId) removeConnection(connection) } - def changeConnectionKeyInterest(connection: Connection, ops: Int) { + private def changeConnectionKeyInterest(connection: Connection, ops: Int) { keyInterestChangeRequests += ((connection.key, ops)) } - def receiveMessage(connection: Connection, message: Message) { + private def receiveMessage(connection: Connection, message: Message) { val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) logDebug("Received [" + message + "] from [" + connectionManagerId + "]") val runnable = new Runnable() { @@ -351,7 +345,6 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private[spark] object ConnectionManager { def main(args: Array[String]) { - val manager = new ConnectionManager(9999) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala index 24b4909380c66fb6f49dee5721e60229ebd8ba76..de2dce161a0906f727f563674220f061dcb6297d 100644 --- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala +++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala @@ -20,7 +20,7 @@ private[spark] class ApproximateActionListener[T, U, R]( extends JobListener { val startTime = System.currentTimeMillis() - val totalTasks = rdd.splits.size + val totalTasks = rdd.partitions.size var finishedTasks = 0 var failure: Option[Exception] = None // Set if the job has failed (permanently) var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index 2c022f88e0defb7445463cb78e4ee7c57042d68d..7348c4f15bad60a94e736a2e10cec2a018233e99 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -1,9 +1,9 @@ package spark.rdd import scala.collection.mutable.HashMap -import spark.{RDD, SparkContext, SparkEnv, Split, TaskContext} +import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext} -private[spark] class BlockRDDSplit(val blockId: String, idx: Int) extends Split { +private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition { val index = idx } @@ -11,10 +11,6 @@ private[spark] class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) extends RDD[T](sc, Nil) { - @transient var splits_ : Array[Split] = (0 until blockIds.size).map(i => { - new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split] - }).toArray - @transient lazy val locations_ = { val blockManager = SparkEnv.get.blockManager /*val locations = blockIds.map(id => blockManager.getLocations(id))*/ @@ -22,11 +18,14 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St HashMap(blockIds.zip(locations):_*) } - override def getSplits = splits_ + override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => { + new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] + }).toArray - override def compute(split: Split, context: TaskContext): Iterator[T] = { + + override def compute(split: Partition, context: TaskContext): Iterator[T] = { val blockManager = SparkEnv.get.blockManager - val blockId = split.asInstanceOf[BlockRDDSplit].blockId + val blockId = split.asInstanceOf[BlockRDDPartition].blockId blockManager.get(blockId) match { case Some(block) => block.asInstanceOf[Iterator[T]] case None => @@ -34,11 +33,8 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St } } - override def getPreferredLocations(split: Split) = - locations_(split.asInstanceOf[BlockRDDSplit].blockId) + override def getPreferredLocations(split: Partition): Seq[String] = + locations_(split.asInstanceOf[BlockRDDPartition].blockId) - override def clearDependencies() { - splits_ = null - } } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 0f9ca0653198f35c25122c0dedfe22ef93849762..38600b8be4e544206ef4aedebfcb79d72efdb187 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -5,22 +5,22 @@ import spark._ private[spark] -class CartesianSplit( +class CartesianPartition( idx: Int, @transient rdd1: RDD[_], @transient rdd2: RDD[_], s1Index: Int, s2Index: Int - ) extends Split { - var s1 = rdd1.splits(s1Index) - var s2 = rdd2.splits(s2Index) + ) extends Partition { + var s1 = rdd1.partitions(s1Index) + var s2 = rdd2.partitions(s2Index) override val index: Int = idx @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { // Update the reference to parent split at the time of task serialization - s1 = rdd1.splits(s1Index) - s2 = rdd2.splits(s2Index) + s1 = rdd1.partitions(s1Index) + s2 = rdd2.partitions(s2Index) oos.defaultWriteObject() } } @@ -33,39 +33,40 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( extends RDD[Pair[T, U]](sc, Nil) with Serializable { - val numSplitsInRdd2 = rdd2.splits.size + val numPartitionsInRdd2 = rdd2.partitions.size - override def getSplits: Array[Split] = { + override def getPartitions: Array[Partition] = { // create the cross product split - val array = new Array[Split](rdd1.splits.size * rdd2.splits.size) - for (s1 <- rdd1.splits; s2 <- rdd2.splits) { - val idx = s1.index * numSplitsInRdd2 + s2.index - array(idx) = new CartesianSplit(idx, rdd1, rdd2, s1.index, s2.index) + val array = new Array[Partition](rdd1.partitions.size * rdd2.partitions.size) + for (s1 <- rdd1.partitions; s2 <- rdd2.partitions) { + val idx = s1.index * numPartitionsInRdd2 + s2.index + array(idx) = new CartesianPartition(idx, rdd1, rdd2, s1.index, s2.index) } array } - override def getPreferredLocations(split: Split) = { - val currSplit = split.asInstanceOf[CartesianSplit] + override def getPreferredLocations(split: Partition): Seq[String] = { + val currSplit = split.asInstanceOf[CartesianPartition] rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) } - override def compute(split: Split, context: TaskContext) = { - val currSplit = split.asInstanceOf[CartesianSplit] + override def compute(split: Partition, context: TaskContext) = { + val currSplit = split.asInstanceOf[CartesianPartition] for (x <- rdd1.iterator(currSplit.s1, context); y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) } override def getDependencies: Seq[Dependency[_]] = List( new NarrowDependency(rdd1) { - def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2) + def getParents(id: Int): Seq[Int] = List(id / numPartitionsInRdd2) }, new NarrowDependency(rdd2) { - def getParents(id: Int): Seq[Int] = List(id % numSplitsInRdd2) + def getParents(id: Int): Seq[Int] = List(id % numPartitionsInRdd2) } ) override def clearDependencies() { + super.clearDependencies() rdd1 = null rdd2 = null } diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index 96b593ba7ca6d282c7158d9049b054afc8a27ea5..9e37bdf659201c2ec7be7dd84de83881d34d8704 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -9,7 +9,7 @@ import org.apache.hadoop.fs.Path import java.io.{File, IOException, EOFException} import java.text.NumberFormat -private[spark] class CheckpointRDDSplit(val index: Int) extends Split {} +private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} /** * This RDD represents a RDD checkpoint file (similar to HadoopRDD). @@ -20,29 +20,27 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) - @transient val splits_ : Array[Split] = { + override def getPartitions: Array[Partition] = { val dirContents = fs.listStatus(new Path(checkpointPath)) - val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted - val numSplits = splitFiles.size - if (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || - !splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1))) { + val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted + val numPartitions = partitionFiles.size + if (numPartitions > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) || + ! partitionFiles(numPartitions-1).endsWith(CheckpointRDD.splitIdToFile(numPartitions-1)))) { throw new SparkException("Invalid checkpoint directory: " + checkpointPath) } - Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i)) + Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i)) } checkpointData = Some(new RDDCheckpointData[T](this)) checkpointData.get.cpFile = Some(checkpointPath) - override def getSplits = splits_ - - override def getPreferredLocations(split: Split): Seq[String] = { + override def getPreferredLocations(split: Partition): Seq[String] = { val status = fs.getFileStatus(new Path(checkpointPath)) val locations = fs.getFileBlockLocations(status, 0, status.getLen) locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost") } - override def compute(split: Split, context: TaskContext): Iterator[T] = { + override def compute(split: Partition, context: TaskContext): Iterator[T] = { val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)) CheckpointRDD.readFromFile(file, context) } @@ -109,7 +107,7 @@ private[spark] object CheckpointRDD extends Logging { deserializeStream.asIterator.asInstanceOf[Iterator[T]] } - // Test whether CheckpointRDD generate expected number of splits despite + // Test whether CheckpointRDD generate expected number of partitions despite // each split file having multiple blocks. This needs to be run on a // cluster (mesos or standalone) using HDFS. def main(args: Array[String]) { @@ -122,8 +120,8 @@ private[spark] object CheckpointRDD extends Logging { val fs = path.getFileSystem(new Configuration()) sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) - assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same") - assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same") + assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") + assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same") fs.delete(path) } } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 4893fe8d784bf14cba83d55298c5bd533083bbda..5200fb6b656ade45d1a8ca682f9f736ce1d769c5 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -5,7 +5,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer -import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} +import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Partition, TaskContext} import spark.{Dependency, OneToOneDependency, ShuffleDependency} @@ -14,13 +14,13 @@ private[spark] sealed trait CoGroupSplitDep extends Serializable private[spark] case class NarrowCoGroupSplitDep( rdd: RDD[_], splitIndex: Int, - var split: Split + var split: Partition ) extends CoGroupSplitDep { @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { // Update the reference to parent split at the time of task serialization - split = rdd.splits(splitIndex) + split = rdd.partitions(splitIndex) oos.defaultWriteObject() } } @@ -28,7 +28,7 @@ private[spark] case class NarrowCoGroupSplitDep( private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep private[spark] -class CoGroupSplit(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Split with Serializable { +class CoGroupPartition(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Partition with Serializable { override val index: Int = idx override def hashCode(): Int = idx } @@ -40,49 +40,45 @@ private[spark] class CoGroupAggregator { (b1, b2) => b1 ++ b2 }) with Serializable -class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) - extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) with Logging { +class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner) + extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { - val aggr = new CoGroupAggregator + private val aggr = new CoGroupAggregator - @transient var deps_ = { - val deps = new ArrayBuffer[Dependency[_]] - for ((rdd, index) <- rdds.zipWithIndex) { + override def getDependencies: Seq[Dependency[_]] = { + rdds.map { rdd => if (rdd.partitioner == Some(part)) { logInfo("Adding one-to-one dependency with " + rdd) - deps += new OneToOneDependency(rdd) + new OneToOneDependency(rdd) } else { logInfo("Adding shuffle dependency with " + rdd) val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true) - deps += new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part) + new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part) } } - deps.toList } - override def getDependencies = deps_ - - @transient var splits_ : Array[Split] = { - val array = new Array[Split](part.numPartitions) + override def getPartitions: Array[Partition] = { + val array = new Array[Partition](part.numPartitions) for (i <- 0 until array.size) { - array(i) = new CoGroupSplit(i, rdds.zipWithIndex.map { case (r, j) => + // Each CoGroupPartition will have a dependency per contributing RDD + array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => + // Assume each RDD contributed a single dependency, and get it dependencies(j) match { case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId): CoGroupSplitDep + new ShuffleCoGroupSplitDep(s.shuffleId) case _ => - new NarrowCoGroupSplitDep(r, i, r.splits(i)): CoGroupSplitDep + new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) } }.toList) } array } - override def getSplits = splits_ - override val partitioner = Some(part) - override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { - val split = s.asInstanceOf[CoGroupSplit] + override def compute(s: Partition, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { + val split = s.asInstanceOf[CoGroupPartition] val numRdds = split.deps.size // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] @@ -97,7 +93,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) } } for ((dep, depNum) <- split.deps.zipWithIndex) dep match { - case NarrowCoGroupSplitDep(rdd, itsSplitIndex, itsSplit) => { + case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { // Read them from the parent for ((k, v) <- rdd.iterator(itsSplit, context)) { getSeq(k.asInstanceOf[K])(depNum) += v @@ -115,8 +111,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) } override def clearDependencies() { - deps_ = null - splits_ = null + super.clearDependencies() rdds = null } } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 4c57434b65f9a2f8fcafdcfa8225d9297cefbe16..0d16cf6e85459156392ed2d168a7a5a06d4acf0a 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -1,19 +1,19 @@ package spark.rdd -import spark.{Dependency, OneToOneDependency, NarrowDependency, RDD, Split, TaskContext} +import spark.{Dependency, OneToOneDependency, NarrowDependency, RDD, Partition, TaskContext} import java.io.{ObjectOutputStream, IOException} -private[spark] case class CoalescedRDDSplit( +private[spark] case class CoalescedRDDPartition( index: Int, @transient rdd: RDD[_], parentsIndices: Array[Int] - ) extends Split { - var parents: Seq[Split] = parentsIndices.map(rdd.splits(_)) + ) extends Partition { + var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_)) @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { // Update the reference to parent split at the time of task serialization - parents = parentsIndices.map(rdd.splits(_)) + parents = parentsIndices.map(rdd.partitions(_)) oos.defaultWriteObject() } } @@ -31,33 +31,34 @@ class CoalescedRDD[T: ClassManifest]( maxPartitions: Int) extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies - override def getSplits: Array[Split] = { - val prevSplits = prev.splits + override def getPartitions: Array[Partition] = { + val prevSplits = prev.partitions if (prevSplits.length < maxPartitions) { - prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) } + prevSplits.map(_.index).map{idx => new CoalescedRDDPartition(idx, prev, Array(idx)) } } else { (0 until maxPartitions).map { i => val rangeStart = (i * prevSplits.length) / maxPartitions val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions - new CoalescedRDDSplit(i, prev, (rangeStart until rangeEnd).toArray) + new CoalescedRDDPartition(i, prev, (rangeStart until rangeEnd).toArray) }.toArray } } - override def compute(split: Split, context: TaskContext): Iterator[T] = { - split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit => + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + split.asInstanceOf[CoalescedRDDPartition].parents.iterator.flatMap { parentSplit => firstParent[T].iterator(parentSplit, context) } } - override def getDependencies: Seq[Dependency[_]] = List( - new NarrowDependency(prev) { + override def getDependencies: Seq[Dependency[_]] = { + Seq(new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = - splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices - } - ) + partitions(id).asInstanceOf[CoalescedRDDPartition].parentsIndices + }) + } override def clearDependencies() { + super.clearDependencies() prev = null } } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index 6dbe235bd9f15b304024354867b20561cb71f74e..c84ec39d21ff7e8f8414a08920a52c285b689d8e 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -1,16 +1,16 @@ package spark.rdd -import spark.{OneToOneDependency, RDD, Split, TaskContext} +import spark.{OneToOneDependency, RDD, Partition, TaskContext} private[spark] class FilteredRDD[T: ClassManifest]( prev: RDD[T], f: T => Boolean) extends RDD[T](prev) { - override def getSplits = firstParent[T].splits + override def getPartitions: Array[Partition] = firstParent[T].partitions override val partitioner = prev.partitioner // Since filter cannot change a partition's keys - override def compute(split: Split, context: TaskContext) = + override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator(split, context).filter(f) } diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index 1b604c66e2fa52c849c38b1f16738b0949a7b405..8ebc77892514c86d0eec4173daf770ebe5f75b95 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -1,6 +1,6 @@ package spark.rdd -import spark.{RDD, Split, TaskContext} +import spark.{RDD, Partition, TaskContext} private[spark] @@ -9,8 +9,8 @@ class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( f: T => TraversableOnce[U]) extends RDD[U](prev) { - override def getSplits = firstParent[T].splits + override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Split, context: TaskContext) = + override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator(split, context).flatMap(f) } diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index 051bffed192bc39a9ac3084ffbec49acdb524340..e16c7ba881977a9f7bd2a593f122d22a596d88c4 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -1,12 +1,12 @@ package spark.rdd -import spark.{RDD, Split, TaskContext} +import spark.{RDD, Partition, TaskContext} private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev) { - override def getSplits = firstParent[T].splits + override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Split, context: TaskContext) = + override def compute(split: Partition, context: TaskContext) = Array(firstParent[T].iterator(split, context).toArray).iterator } diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index f547f53812661da253353384e43dbe2702a3cb68..78097502bca48d207a65563718079c638490d987 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -15,14 +15,14 @@ import org.apache.hadoop.mapred.RecordReader import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.util.ReflectionUtils -import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext} +import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext} /** * A Spark split class that wraps around a Hadoop InputSplit. */ -private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit) - extends Split { +private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit) + extends Partition { val inputSplit = new SerializableWritable[InputSplit](s) @@ -42,18 +42,17 @@ class HadoopRDD[K, V]( keyClass: Class[K], valueClass: Class[V], minSplits: Int) - extends RDD[(K, V)](sc, Nil) { + extends RDD[(K, V)](sc, Nil) with Logging { // A Hadoop JobConf can be about 10 KB, which is pretty big, so broadcast it - val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) - @transient - val splits_ : Array[Split] = { + override def getPartitions: Array[Partition] = { val inputFormat = createInputFormat(conf) val inputSplits = inputFormat.getSplits(conf, minSplits) - val array = new Array[Split](inputSplits.size) + val array = new Array[Partition](inputSplits.size) for (i <- 0 until inputSplits.size) { - array(i) = new HadoopSplit(id, i, inputSplits(i)) + array(i) = new HadoopPartition(id, i, inputSplits(i)) } array } @@ -63,10 +62,8 @@ class HadoopRDD[K, V]( .asInstanceOf[InputFormat[K, V]] } - override def getSplits = splits_ - - override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] { - val split = theSplit.asInstanceOf[HadoopSplit] + override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] { + val split = theSplit.asInstanceOf[HadoopPartition] var reader: RecordReader[K, V] = null val conf = confBroadcast.value.value @@ -74,7 +71,7 @@ class HadoopRDD[K, V]( reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => reader.close()) + context.addOnCompleteCallback{ () => close() } val key: K = reader.createKey() val value: V = reader.createValue() @@ -91,9 +88,6 @@ class HadoopRDD[K, V]( } gotNext = true } - if (finished) { - reader.close() - } !finished } @@ -107,11 +101,19 @@ class HadoopRDD[K, V]( gotNext = false (key, value) } + + private def close() { + try { + reader.close() + } catch { + case e: Exception => logWarning("Exception in RecordReader.close()", e) + } + } } - override def getPreferredLocations(split: Split) = { + override def getPreferredLocations(split: Partition): Seq[String] = { // TODO: Filtering out "localhost" in case of file:// URLs - val hadoopSplit = split.asInstanceOf[HadoopSplit] + val hadoopSplit = split.asInstanceOf[HadoopPartition] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") } diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index 073f7d7d2aad251c4240bd665b9fc02e90eec8a8..d283c5b2bb8dfc79c20240476ca9271020fbc480 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -1,6 +1,6 @@ package spark.rdd -import spark.{RDD, Split, TaskContext} +import spark.{RDD, Partition, TaskContext} private[spark] @@ -13,8 +13,8 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - override def getSplits = firstParent[T].splits + override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Split, context: TaskContext) = + override def compute(split: Partition, context: TaskContext) = f(firstParent[T].iterator(split, context)) -} \ No newline at end of file +} diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala similarity index 57% rename from core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala rename to core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala index 2ddc3d01b647a573d85aa0c3622341fdd3ed1adb..afb7504ba120dbc4a42beef8b001dbfc7cbb722f 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithIndexRDD.scala @@ -1,24 +1,24 @@ package spark.rdd -import spark.{RDD, Split, TaskContext} +import spark.{RDD, Partition, TaskContext} /** - * A variant of the MapPartitionsRDD that passes the split index into the + * A variant of the MapPartitionsRDD that passes the partition index into the * closure. This can be used to generate or collect partition specific * information such as the number of tuples in a partition. */ private[spark] -class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( +class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean ) extends RDD[U](prev) { - override def getSplits = firstParent[T].splits + override def getPartitions: Array[Partition] = firstParent[T].partitions override val partitioner = if (preservesPartitioning) prev.partitioner else None - override def compute(split: Split, context: TaskContext) = + override def compute(split: Partition, context: TaskContext) = f(split.index, firstParent[T].iterator(split, context)) -} \ No newline at end of file +} diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 5466c9c657fcb03b20f578ce4456aa4c5cc0c1ed..af07311b6d0385673150866bb9288e7a270c1da9 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -1,13 +1,13 @@ package spark.rdd -import spark.{RDD, Split, TaskContext} +import spark.{RDD, Partition, TaskContext} private[spark] class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U) extends RDD[U](prev) { - override def getSplits = firstParent[T].splits + override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Split, context: TaskContext) = + override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator(split, context).map(f) } diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index c3b155fcbddd6545b2aa45fa45c9bd610a56d9fa..df2361025c75327837c4e13184e762261e4b7509 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -7,12 +7,12 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import spark.{Dependency, RDD, SerializableWritable, SparkContext, Split, TaskContext} +import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext} private[spark] -class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable) - extends Split { +class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable) + extends Partition { val serializableHadoopSplit = new SerializableWritable(rawSplit) @@ -26,10 +26,11 @@ class NewHadoopRDD[K, V]( valueClass: Class[V], @transient conf: Configuration) extends RDD[(K, V)](sc, Nil) - with HadoopMapReduceUtil { + with HadoopMapReduceUtil + with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it - val confBroadcast = sc.broadcast(new SerializableWritable(conf)) + private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) // private val serializableConf = new SerializableWritable(conf) private val jobtrackerId: String = { @@ -39,21 +40,19 @@ class NewHadoopRDD[K, V]( @transient private val jobId = new JobID(jobtrackerId, id) - @transient private val splits_ : Array[Split] = { + override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance val jobContext = newJobContext(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[Split](rawSplits.size) + val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { - result(i) = new NewHadoopSplit(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) } result } - override def getSplits = splits_ - - override def compute(theSplit: Split, context: TaskContext) = new Iterator[(K, V)] { - val split = theSplit.asInstanceOf[NewHadoopSplit] + override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] { + val split = theSplit.asInstanceOf[NewHadoopPartition] val conf = confBroadcast.value.value val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0) val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) @@ -63,7 +62,7 @@ class NewHadoopRDD[K, V]( reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => reader.close()) + context.addOnCompleteCallback(() => close()) var havePair = false var finished = false @@ -83,10 +82,18 @@ class NewHadoopRDD[K, V]( havePair = false return (reader.getCurrentKey, reader.getCurrentValue) } + + private def close() { + try { + reader.close() + } catch { + case e: Exception => logWarning("Exception in RecordReader.close()", e) + } + } } - override def getPreferredLocations(split: Split) = { - val theSplit = split.asInstanceOf[NewHadoopSplit] + override def getPreferredLocations(split: Partition): Seq[String] = { + val theSplit = split.asInstanceOf[NewHadoopPartition] theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost") } } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala similarity index 77% rename from core/src/main/scala/spark/ParallelCollection.scala rename to core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala index 10adcd53ecafdd83b34d232e2270eb9e63305f6e..07585a88ceb36930ee626e505d53b5030bb8b02f 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/rdd/ParallelCollectionRDD.scala @@ -1,28 +1,29 @@ -package spark +package spark.rdd import scala.collection.immutable.NumericRange import scala.collection.mutable.ArrayBuffer import scala.collection.Map +import spark.{RDD, TaskContext, SparkContext, Partition} -private[spark] class ParallelCollectionSplit[T: ClassManifest]( +private[spark] class ParallelCollectionPartition[T: ClassManifest]( val rddId: Long, val slice: Int, values: Seq[T]) - extends Split with Serializable { + extends Partition with Serializable { def iterator: Iterator[T] = values.iterator override def hashCode(): Int = (41 * (41 + rddId) + slice).toInt override def equals(other: Any): Boolean = other match { - case that: ParallelCollectionSplit[_] => (this.rddId == that.rddId && this.slice == that.slice) + case that: ParallelCollectionPartition[_] => (this.rddId == that.rddId && this.slice == that.slice) case _ => false } override val index: Int = slice } -private[spark] class ParallelCollection[T: ClassManifest]( +private[spark] class ParallelCollectionRDD[T: ClassManifest]( @transient sc: SparkContext, @transient data: Seq[T], numSlices: Int, @@ -33,26 +34,20 @@ private[spark] class ParallelCollection[T: ClassManifest]( // instead. // UPDATE: A parallel collection can be checkpointed to HDFS, which achieves this goal. - @transient var splits_ : Array[Split] = { - val slices = ParallelCollection.slice(data, numSlices).toArray - slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray + override def getPartitions: Array[Partition] = { + val slices = ParallelCollectionRDD.slice(data, numSlices).toArray + slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray } - override def getSplits = splits_ + override def compute(s: Partition, context: TaskContext) = + s.asInstanceOf[ParallelCollectionPartition[T]].iterator - override def compute(s: Split, context: TaskContext) = - s.asInstanceOf[ParallelCollectionSplit[T]].iterator - - override def getPreferredLocations(s: Split): Seq[String] = { + override def getPreferredLocations(s: Partition): Seq[String] = { locationPrefs.getOrElse(s.index, Nil) } - - override def clearDependencies() { - splits_ = null - } } -private object ParallelCollection { +private object ParallelCollectionRDD { /** * Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index a50ce751718c4b640aaf0b3854c03c8e54d9d9e5..41ff62dd2285749b0829b29a8c9293aee4dbbcfd 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -1,9 +1,9 @@ package spark.rdd -import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext} +import spark.{NarrowDependency, RDD, SparkEnv, Partition, TaskContext} -class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split { +class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition) extends Partition { override val index = idx } @@ -16,15 +16,15 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo extends NarrowDependency[T](rdd) { @transient - val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index)) - .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split } + val partitions: Array[Partition] = rdd.partitions.filter(s => partitionFilterFunc(s.index)) + .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } override def getParents(partitionId: Int) = List(partitions(partitionId).index) } /** - * A RDD used to prune RDD partitions/splits so we can avoid launching tasks on + * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. @@ -34,9 +34,21 @@ class PartitionPruningRDD[T: ClassManifest]( @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { - override def compute(split: Split, context: TaskContext) = firstParent[T].iterator( - split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context) + override def compute(split: Partition, context: TaskContext) = firstParent[T].iterator( + split.asInstanceOf[PartitionPruningRDDPartition].parentSplit, context) - override protected def getSplits = + override protected def getPartitions: Array[Partition] = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions } + + +object PartitionPruningRDD { + + /** + * Create a PartitionPruningRDD. This function can be used to create the PartitionPruningRDD + * when its type T is not known at compile time. + */ + def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = { + new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassManifest) + } +} diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index 6631f83510cb6851a9a6002792415d6ca556f1c8..962a1b21ad1d3ff8c32e461984ec444330e4bf67 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -8,7 +8,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source -import spark.{RDD, SparkEnv, Split, TaskContext} +import spark.{RDD, SparkEnv, Partition, TaskContext} /** @@ -27,9 +27,9 @@ class PipedRDD[T: ClassManifest]( // using a standard StringTokenizer (i.e. by spaces) def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) - override def getSplits = firstParent[T].splits + override def getPartitions: Array[Partition] = firstParent[T].partitions - override def compute(split: Split, context: TaskContext): Iterator[String] = { + override def compute(split: Partition, context: TaskContext): Iterator[String] = { val pb = new ProcessBuilder(command) // Add the environmental variables to the process. val currentEnvVars = pb.environment() diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index e24ad23b2142315a67d6f79c53fc9b0282c613a7..243673f1518729f6228f8f857ad6fc19b29fd08e 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -5,10 +5,10 @@ import java.util.Random import cern.jet.random.Poisson import cern.jet.random.engine.DRand -import spark.{RDD, Split, TaskContext} +import spark.{RDD, Partition, TaskContext} private[spark] -class SampledRDDSplit(val prev: Split, val seed: Int) extends Split with Serializable { +class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable { override val index: Int = prev.index } @@ -19,18 +19,16 @@ class SampledRDD[T: ClassManifest]( seed: Int) extends RDD[T](prev) { - @transient var splits_ : Array[Split] = { + override def getPartitions: Array[Partition] = { val rg = new Random(seed) - firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) + firstParent[T].partitions.map(x => new SampledRDDPartition(x, rg.nextInt)) } - override def getSplits = splits_ + override def getPreferredLocations(split: Partition): Seq[String] = + firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDPartition].prev) - override def getPreferredLocations(split: Split) = - firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) - - override def compute(splitIn: Split, context: TaskContext) = { - val split = splitIn.asInstanceOf[SampledRDDSplit] + override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = { + val split = splitIn.asInstanceOf[SampledRDDPartition] if (withReplacement) { // For large datasets, the expected number of occurrences of each element in a sample with // replacement is Poisson(frac). We use that to get a count for each element. @@ -48,8 +46,4 @@ class SampledRDD[T: ClassManifest]( firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac)) } } - - override def clearDependencies() { - splits_ = null - } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index d3964786736156e3378fbef1d50f306e99ba0607..c2f118305f33f260b17af5bf49dbd1d5fd11b538 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -1,9 +1,9 @@ package spark.rdd -import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Split, TaskContext} +import spark.{Partitioner, RDD, SparkEnv, ShuffleDependency, Partition, TaskContext} import spark.SparkContext._ -private[spark] class ShuffledRDDSplit(val idx: Int) extends Split { +private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { override val index = idx override def hashCode(): Int = idx } @@ -22,9 +22,11 @@ class ShuffledRDD[K, V]( override val partitioner = Some(part) - override def getSplits = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) + override def getPartitions: Array[Partition] = { + Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i)) + } - override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = { + override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala new file mode 100644 index 0000000000000000000000000000000000000000..daf9cc993cf42e9e963e986d73cdad0d6d708059 --- /dev/null +++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala @@ -0,0 +1,108 @@ +package spark.rdd + +import java.util.{HashSet => JHashSet} +import scala.collection.JavaConversions._ +import spark.RDD +import spark.Partitioner +import spark.Dependency +import spark.TaskContext +import spark.Partition +import spark.SparkEnv +import spark.ShuffleDependency +import spark.OneToOneDependency + +/** + * An optimized version of cogroup for set difference/subtraction. + * + * It is possible to implement this operation with just `cogroup`, but + * that is less efficient because all of the entries from `rdd2`, for + * both matching and non-matching values in `rdd1`, are kept in the + * JHashMap until the end. + * + * With this implementation, only the entries from `rdd1` are kept in-memory, + * and the entries from `rdd2` are essentially streamed, as we only need to + * touch each once to decide if the value needs to be removed. + * + * This is particularly helpful when `rdd1` is much smaller than `rdd2`, as + * you can use `rdd1`'s partitioner/partition size and not worry about running + * out of memory because of the size of `rdd2`. + */ +private[spark] class SubtractedRDD[T: ClassManifest]( + @transient var rdd1: RDD[T], + @transient var rdd2: RDD[T], + part: Partitioner) extends RDD[T](rdd1.context, Nil) { + + override def getDependencies: Seq[Dependency[_]] = { + Seq(rdd1, rdd2).map { rdd => + if (rdd.partitioner == Some(part)) { + logInfo("Adding one-to-one dependency with " + rdd) + new OneToOneDependency(rdd) + } else { + logInfo("Adding shuffle dependency with " + rdd) + val mapSideCombinedRDD = rdd.mapPartitions(i => { + val set = new JHashSet[T]() + while (i.hasNext) { + set.add(i.next) + } + set.iterator + }, true) + // ShuffleDependency requires a tuple (k, v), which it will partition by k. + // We need this to partition to map to the same place as the k for + // OneToOneDependency, which means: + // - for already-tupled RDD[(A, B)], into getPartition(a) + // - for non-tupled RDD[C], into getPartition(c) + val part2 = new Partitioner() { + def numPartitions = part.numPartitions + def getPartition(key: Any) = key match { + case (k, v) => part.getPartition(k) + case k => part.getPartition(k) + } + } + new ShuffleDependency(mapSideCombinedRDD.map((_, null)), part2) + } + } + } + + override def getPartitions: Array[Partition] = { + val array = new Array[Partition](part.numPartitions) + for (i <- 0 until array.size) { + // Each CoGroupPartition will depend on rdd1 and rdd2 + array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => + dependencies(j) match { + case s: ShuffleDependency[_, _] => + new ShuffleCoGroupSplitDep(s.shuffleId) + case _ => + new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) + } + }.toList) + } + array + } + + override val partitioner = Some(part) + + override def compute(p: Partition, context: TaskContext): Iterator[T] = { + val partition = p.asInstanceOf[CoGroupPartition] + val set = new JHashSet[T] + def integrate(dep: CoGroupSplitDep, op: T => Unit) = dep match { + case NarrowCoGroupSplitDep(rdd, _, itsSplit) => + for (k <- rdd.iterator(itsSplit, context)) + op(k.asInstanceOf[T]) + case ShuffleCoGroupSplitDep(shuffleId) => + for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index)) + op(k.asInstanceOf[T]) + } + // the first dep is rdd1; add all keys to the set + integrate(partition.deps(0), set.add) + // the second dep is rdd2; remove all of its keys from the set + integrate(partition.deps(1), set.remove) + set.iterator + } + + override def clearDependencies() { + super.clearDependencies() + rdd1 = null + rdd2 = null + } + +} \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index 26a2d511f2670621de66d5c5f9f271ff130ce031..2c52a67e22635bf196a2d5341bf0972a2ab3075c 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -1,13 +1,13 @@ package spark.rdd import scala.collection.mutable.ArrayBuffer -import spark.{Dependency, RangeDependency, RDD, SparkContext, Split, TaskContext} +import spark.{Dependency, RangeDependency, RDD, SparkContext, Partition, TaskContext} import java.io.{ObjectOutputStream, IOException} -private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int) - extends Split { +private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int) + extends Partition { - var split: Split = rdd.splits(splitIndex) + var split: Partition = rdd.partitions(splitIndex) def iterator(context: TaskContext) = rdd.iterator(split, context) @@ -18,7 +18,7 @@ private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIn @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { // Update the reference to parent split at the time of task serialization - split = rdd.splits(splitIndex) + split = rdd.partitions(splitIndex) oos.defaultWriteObject() } } @@ -28,11 +28,11 @@ class UnionRDD[T: ClassManifest]( @transient var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil since we implement getDependencies - override def getSplits: Array[Split] = { - val array = new Array[Split](rdds.map(_.splits.size).sum) + override def getPartitions: Array[Partition] = { + val array = new Array[Partition](rdds.map(_.partitions.size).sum) var pos = 0 - for (rdd <- rdds; split <- rdd.splits) { - array(pos) = new UnionSplit(pos, rdd, split.index) + for (rdd <- rdds; split <- rdd.partitions) { + array(pos) = new UnionPartition(pos, rdd, split.index) pos += 1 } array @@ -42,19 +42,15 @@ class UnionRDD[T: ClassManifest]( val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for (rdd <- rdds) { - deps += new RangeDependency(rdd, 0, pos, rdd.splits.size) - pos += rdd.splits.size + deps += new RangeDependency(rdd, 0, pos, rdd.partitions.size) + pos += rdd.partitions.size } deps } - override def compute(s: Split, context: TaskContext): Iterator[T] = - s.asInstanceOf[UnionSplit[T]].iterator(context) + override def compute(s: Partition, context: TaskContext): Iterator[T] = + s.asInstanceOf[UnionPartition[T]].iterator(context) - override def getPreferredLocations(s: Split): Seq[String] = - s.asInstanceOf[UnionSplit[T]].preferredLocations() - - override def clearDependencies() { - rdds = null - } + override def getPreferredLocations(s: Partition): Seq[String] = + s.asInstanceOf[UnionPartition[T]].preferredLocations() } diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala index e5df6d8c7239b83a328f1b898e8c395c9c3e459c..e80ec17aa5e0a4bbffd94b835fb18cff597c1f00 100644 --- a/core/src/main/scala/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala @@ -1,17 +1,17 @@ package spark.rdd -import spark.{OneToOneDependency, RDD, SparkContext, Split, TaskContext} +import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext} import java.io.{ObjectOutputStream, IOException} -private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest]( +private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest]( idx: Int, @transient rdd1: RDD[T], @transient rdd2: RDD[U] - ) extends Split { + ) extends Partition { - var split1 = rdd1.splits(idx) - var split2 = rdd1.splits(idx) + var split1 = rdd1.partitions(idx) + var split2 = rdd1.partitions(idx) override val index: Int = idx def splits = (split1, split2) @@ -19,8 +19,8 @@ private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest]( @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { // Update the reference to parent split at the time of task serialization - split1 = rdd1.splits(idx) - split2 = rdd2.splits(idx) + split1 = rdd1.partitions(idx) + split2 = rdd2.partitions(idx) oos.defaultWriteObject() } } @@ -29,31 +29,31 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest]( sc: SparkContext, var rdd1: RDD[T], var rdd2: RDD[U]) - extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) - with Serializable { + extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) { - override def getSplits: Array[Split] = { - if (rdd1.splits.size != rdd2.splits.size) { + override def getPartitions: Array[Partition] = { + if (rdd1.partitions.size != rdd2.partitions.size) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } - val array = new Array[Split](rdd1.splits.size) - for (i <- 0 until rdd1.splits.size) { - array(i) = new ZippedSplit(i, rdd1, rdd2) + val array = new Array[Partition](rdd1.partitions.size) + for (i <- 0 until rdd1.partitions.size) { + array(i) = new ZippedPartition(i, rdd1, rdd2) } array } - override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = { - val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits + override def compute(s: Partition, context: TaskContext): Iterator[(T, U)] = { + val (split1, split2) = s.asInstanceOf[ZippedPartition[T, U]].splits rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context)) } - override def getPreferredLocations(s: Split): Seq[String] = { - val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits + override def getPreferredLocations(s: Partition): Seq[String] = { + val (split1, split2) = s.asInstanceOf[ZippedPartition[T, U]].splits rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2)) } override def clearDependencies() { + super.clearDependencies() rdd1 = null rdd2 = null } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 319eef69780edce1dad13449fdc22bdc612d715d..bf0837c0660cb2090a86a4aa75ec34bcd8b41b32 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -106,7 +106,7 @@ class DAGScheduler( private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { if (!cacheLocs.contains(rdd.id)) { - val blockIds = rdd.splits.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray + val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map { locations => locations.map(_.ip).toList }.toArray @@ -141,9 +141,9 @@ class DAGScheduler( private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = { if (shuffleDep != None) { // Kind of ugly: need to register RDDs with the cache and map output tracker here - // since we can't do it in the RDD constructor because # of splits is unknown + // since we can't do it in the RDD constructor because # of partitions is unknown logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") - mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) + mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size) } val id = nextStageId.getAndIncrement() val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority) @@ -162,7 +162,7 @@ class DAGScheduler( if (!visited(r)) { visited += r // Kind of ugly: need to register RDDs with the cache here since - // we can't do it in its constructor because # of splits is unknown + // we can't do it in its constructor because # of partitions is unknown for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => @@ -257,7 +257,7 @@ class DAGScheduler( { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val partitions = (0 until rdd.splits.size).toArray + val partitions = (0 until rdd.partitions.size).toArray eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener)) return listener.awaitResult() // Will throw an exception if the job fails } @@ -386,7 +386,7 @@ class DAGScheduler( try { SparkEnv.set(env) val rdd = job.finalStage.rdd - val split = rdd.splits(job.partitions(0)) + val split = rdd.partitions(job.partitions(0)) val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) @@ -672,7 +672,7 @@ class DAGScheduler( return cached } // If the RDD has some placement preferences (as is the case for input RDDs), get those - val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList + val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList if (rddPrefs != Nil) { return rddPrefs } diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala index 8cd4c661eb70d246c80bfbfe9b19580582928ae0..1721f78f483cf9f8b274e81ef6e2c01327a8b277 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -67,7 +67,7 @@ private[spark] class ResultTask[T, U]( var split = if (rdd == null) { null } else { - rdd.splits(partition) + rdd.partitions(partition) } override def run(attemptId: Long): U = { @@ -85,7 +85,7 @@ private[spark] class ResultTask[T, U]( override def writeExternal(out: ObjectOutput) { RDDCheckpointData.synchronized { - split = rdd.splits(partition) + split = rdd.partitions(partition) out.writeInt(stageId) val bytes = ResultTask.serializeInfo( stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) @@ -107,6 +107,6 @@ private[spark] class ResultTask[T, U]( func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] partition = in.readInt() val outputId = in.readInt() - split = in.readObject().asInstanceOf[Split] + split = in.readObject().asInstanceOf[Partition] } } diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala index bed9f1864fffda3d7029bf6675aa95cf300a204a..59ee3c0a095acc3eab7bedc5c4189538ac3f91cc 100644 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -86,12 +86,12 @@ private[spark] class ShuffleMapTask( var split = if (rdd == null) { null } else { - rdd.splits(partition) + rdd.partitions(partition) } override def writeExternal(out: ObjectOutput) { RDDCheckpointData.synchronized { - split = rdd.splits(partition) + split = rdd.partitions(partition) out.writeInt(stageId) val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) out.writeInt(bytes.length) @@ -112,7 +112,7 @@ private[spark] class ShuffleMapTask( dep = dep_ partition = in.readInt() generation = in.readLong() - split = in.readObject().asInstanceOf[Split] + split = in.readObject().asInstanceOf[Partition] } override def run(attemptId: Long): MapStatus = { diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index 374114d87034cf1851ade92c36fd42bbc72c6e82..552061e46b8c18d53ac06496a9b27c8cc37c5b80 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -28,7 +28,7 @@ private[spark] class Stage( extends Logging { val isShuffleMap = shuffleDep != None - val numPartitions = rdd.splits.size + val numPartitions = rdd.partitions.size val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) var numAvailableOutputs = 0 diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala index 1e4fbdb8742fdac8d33edb6b1e9ec1b3aaaea4d4..d9c2f9517be50ecd6e9370b5f55a4a3032dae6ab 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala @@ -11,6 +11,7 @@ import spark.TaskState.TaskState import spark.scheduler._ import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong +import java.util.{TimerTask, Timer} /** * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call @@ -22,6 +23,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // How often to check for speculative tasks val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong + // Threshold above which we warn user initial TaskSet may be starved + val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong val activeTaskSets = new HashMap[String, TaskSetManager] var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] @@ -30,6 +33,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext) val taskIdToExecutorId = new HashMap[Long, String] val taskSetTaskIds = new HashMap[String, HashSet[Long]] + var hasReceivedTask = false + var hasLaunchedTask = false + val starvationTimer = new Timer(true) + // Incrementing Mesos task IDs val nextTaskId = new AtomicLong(0) @@ -94,6 +101,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext) activeTaskSets(taskSet.id) = manager activeTaskSetsQueue += manager taskSetTaskIds(taskSet.id) = new HashSet[Long]() + + if (hasReceivedTask == false) { + starvationTimer.scheduleAtFixedRate(new TimerTask() { + override def run() { + if (!hasLaunchedTask) { + logWarning("Initial job has not accepted any resources; " + + "check your cluster UI to ensure that workers are registered") + } else { + this.cancel() + } + } + }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) + } + hasReceivedTask = true; } backend.reviveOffers() } @@ -150,6 +171,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } while (launchedTask) } + if (tasks.size > 0) { + hasLaunchedTask = true + } return tasks } } @@ -235,7 +259,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } override def defaultParallelism() = backend.defaultParallelism() - + // Check for speculatable tasks in all our active jobs. def checkSpeculatableTasks() { var shouldRevive = false diff --git a/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala index bba7de6a65c3d17aab47bdfa07c464ee7e801604..8bf838209f3d8183cf6251bd7def1b8f14dffa75 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ExecutorLossReason.scala @@ -12,10 +12,10 @@ class ExecutorLossReason(val message: String) { private[spark] case class ExecutorExited(val exitCode: Int) - extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) { + extends ExecutorLossReason(ExecutorExitCode.explainExitCode(exitCode)) { } private[spark] case class SlaveLost(_message: String = "Slave lost") - extends ExecutorLossReason(_message) { + extends ExecutorLossReason(_message) { } diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 59ff8bcb90fc2620de2b538860cff59e9a2446c8..bb289c9cf391b4d1dcc6f94113117b9468c4a578 100644 --- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -2,14 +2,14 @@ package spark.scheduler.cluster import spark.{Utils, Logging, SparkContext} import spark.deploy.client.{Client, ClientListener} -import spark.deploy.{Command, JobDescription} +import spark.deploy.{Command, ApplicationDescription} import scala.collection.mutable.HashMap private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, sc: SparkContext, master: String, - jobName: String) + appName: String) extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) with ClientListener with Logging { @@ -29,10 +29,11 @@ private[spark] class SparkDeploySchedulerBackend( StandaloneSchedulerBackend.ACTOR_NAME) val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) - val sparkHome = sc.getSparkHome().getOrElse(throw new IllegalArgumentException("must supply spark home for spark standalone")) - val jobDesc = new JobDescription(jobName, maxCores, executorMemory, command, sparkHome) + val sparkHome = sc.getSparkHome().getOrElse( + throw new IllegalArgumentException("must supply spark home for spark standalone")) + val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome) - client = new Client(sc.env.actorSystem, master, jobDesc, this) + client = new Client(sc.env.actorSystem, master, appDesc, this) client.start() } @@ -45,8 +46,8 @@ private[spark] class SparkDeploySchedulerBackend( } } - override def connected(jobId: String) { - logInfo("Connected to Spark cluster with job ID " + jobId) + override def connected(appId: String) { + logInfo("Connected to Spark cluster with app ID " + appId) } override def disconnected() { @@ -67,6 +68,6 @@ private[spark] class SparkDeploySchedulerBackend( case None => SlaveLost(message) } logInfo("Executor %s removed: %s".format(executorId, message)) - scheduler.executorLost(executorId, reason) + removeExecutor(executorId, reason.toString) } } diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala index da7dcf4b6b48e8b5eb851fbef8d48d79e20dc09e..d7660678248b2d9f028d8364bd3caa8cbd47660c 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneClusterMessage.scala @@ -37,3 +37,6 @@ object StatusUpdate { // Internal messages in driver private[spark] case object ReviveOffers extends StandaloneClusterMessage private[spark] case object StopDriver extends StandaloneClusterMessage + +private[spark] case class RemoveExecutor(executorId: String, reason: String) + extends StandaloneClusterMessage diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 082022be1c9da0a487e65879e57814b793ebe838..7a428e3361976b3a8d78b8a8c1c1a7f771399c1a 100644 --- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -68,6 +68,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor sender ! true context.stop(self) + case RemoveExecutor(executorId, reason) => + removeExecutor(executorId, reason) + sender ! true + case Terminated(actor) => actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated")) @@ -100,16 +104,18 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor // Remove a disconnected slave from the cluster def removeExecutor(executorId: String, reason: String) { - logInfo("Slave " + executorId + " disconnected, so removing it") - val numCores = freeCores(executorId) - actorToExecutorId -= executorActor(executorId) - addressToExecutorId -= executorAddress(executorId) - executorActor -= executorId - executorHost -= executorId - freeCores -= executorId - executorHost -= executorId - totalCoreCount.addAndGet(-numCores) - scheduler.executorLost(executorId, SlaveLost(reason)) + if (executorActor.contains(executorId)) { + logInfo("Executor " + executorId + " disconnected, so removing it") + val numCores = freeCores(executorId) + actorToExecutorId -= executorActor(executorId) + addressToExecutorId -= executorAddress(executorId) + executorActor -= executorId + executorHost -= executorId + freeCores -= executorId + executorHost -= executorId + totalCoreCount.addAndGet(-numCores) + scheduler.executorLost(executorId, SlaveLost(reason)) + } } } @@ -139,7 +145,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } } catch { case e: Exception => - throw new SparkException("Error stopping standalone scheduler's master actor", e) + throw new SparkException("Error stopping standalone scheduler's driver actor", e) } } @@ -147,7 +153,20 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor driverActor ! ReviveOffers } - override def defaultParallelism(): Int = math.max(totalCoreCount.get(), 2) + override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism")) + .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2)) + + // Called by subclasses when notified of a lost worker + def removeExecutor(executorId: String, reason: String) { + try { + val timeout = 5.seconds + val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) + Await.result(future, timeout) + } catch { + case e: Exception => + throw new SparkException("Error notifying standalone scheduler's driver actor", e) + } + } } private[spark] object StandaloneSchedulerBackend { diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala index b481ec0a72dce5937f773ca03fa1f9ae6c9eda9d..f4a2994b6d61d5b74ff8e331a2558f90fb4a8baf 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosSchedulerBackend.scala @@ -28,7 +28,7 @@ private[spark] class CoarseMesosSchedulerBackend( scheduler: ClusterScheduler, sc: SparkContext, master: String, - frameworkName: String) + appName: String) extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) with MScheduler with Logging { @@ -76,7 +76,7 @@ private[spark] class CoarseMesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(frameworkName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { { val ret = driver.run() @@ -239,7 +239,11 @@ private[spark] class CoarseMesosSchedulerBackend( override def slaveLost(d: SchedulerDriver, slaveId: SlaveID) { logInfo("Mesos slave lost: " + slaveId.getValue) synchronized { - slaveIdsWithExecutors -= slaveId.getValue + if (slaveIdsWithExecutors.contains(slaveId.getValue)) { + // Note that the slave ID corresponds to the executor ID on that slave + slaveIdsWithExecutors -= slaveId.getValue + removeExecutor(slaveId.getValue, "Mesos slave lost") + } } } diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala index 300766d0f5c729d84c0b152f889e5996111aeb4d..ca7fab4cc5ff95f318b33fc2b82e09f870d83f03 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosSchedulerBackend.scala @@ -24,7 +24,7 @@ private[spark] class MesosSchedulerBackend( scheduler: ClusterScheduler, sc: SparkContext, master: String, - frameworkName: String) + appName: String) extends SchedulerBackend with MScheduler with Logging { @@ -49,7 +49,7 @@ private[spark] class MesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(frameworkName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { val ret = driver.run() diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index d9838f65ab4212d0eab3dbe6350387baabe300f9..266191b05f23882f71eea067b9dd3e798c313cd1 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -513,7 +513,7 @@ class BlockManager( } } - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // Partition local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. val remoteRequests = new ArrayBuffer[FetchRequest] for ((address, blockInfos) <- blocksByAddress) { @@ -585,7 +585,7 @@ class BlockManager( resultsGotten += 1 val result = results.take() bytesInFlight -= result.size - if (!fetchRequests.isEmpty && + while (!fetchRequests.isEmpty && (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { sendRequest(fetchRequests.dequeue()) } diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 7e5b820cbbdc6ca145c2eb7c6787bd2c137c80d0..ddbf8821ad15aa172cb248ef3aed45edb3e38f33 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -178,7 +178,11 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { override def run() { logDebug("Shutdown hook called") - localDirs.foreach(localDir => Utils.deleteRecursively(localDir)) + try { + localDirs.foreach(localDir => Utils.deleteRecursively(localDir)) + } catch { + case t: Throwable => logError("Exception while deleting local spark dirs", t) + } } }) } diff --git a/core/src/main/scala/spark/storage/StorageUtils.scala b/core/src/main/scala/spark/storage/StorageUtils.scala index 5f72b67b2cc27caf4eba9ad08d3fd079374dfcef..dec47a9d4113b40dcd4fe25f95a021b3cc1f681f 100644 --- a/core/src/main/scala/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/spark/storage/StorageUtils.scala @@ -63,7 +63,7 @@ object StorageUtils { val rddName = Option(rdd.name).getOrElse(rddKey) val rddStorageLevel = rdd.getStorageLevel - RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.splits.size, memSize, diskSize) + RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, rdd.partitions.size, memSize, diskSize) }.toArray } diff --git a/core/src/main/scala/spark/util/MetadataCleaner.scala b/core/src/main/scala/spark/util/MetadataCleaner.scala index a342d378ffab263f8dc23b58ec7e7f1f204bd045..dafa90671214b4d803f2775534ebee9f5325cc70 100644 --- a/core/src/main/scala/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/spark/util/MetadataCleaner.scala @@ -38,7 +38,7 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging object MetadataCleaner { - def getDelaySeconds = System.getProperty("spark.cleaner.delay", "-1").toInt - def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.delay", delay.toString) } + def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt + def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.ttl", delay.toString) } } diff --git a/core/src/main/scala/spark/util/Vector.scala b/core/src/main/scala/spark/util/Vector.scala index 03559751bc46bea21e8d455323459cc91d73ea9c..835822edb2300969ea40497e9e954f85921ffd5e 100644 --- a/core/src/main/scala/spark/util/Vector.scala +++ b/core/src/main/scala/spark/util/Vector.scala @@ -11,12 +11,16 @@ class Vector(val elements: Array[Double]) extends Serializable { return Vector(length, i => this(i) + other(i)) } + def add(other: Vector) = this + other + def - (other: Vector): Vector = { if (length != other.length) throw new IllegalArgumentException("Vectors of different length") return Vector(length, i => this(i) - other(i)) } + def subtract(other: Vector) = this - other + def dot(other: Vector): Double = { if (length != other.length) throw new IllegalArgumentException("Vectors of different length") @@ -61,10 +65,16 @@ class Vector(val elements: Array[Double]) extends Serializable { this } + def addInPlace(other: Vector) = this +=other + def * (scale: Double): Vector = Vector(length, i => this(i) * scale) + def multiply (d: Double) = this * d + def / (d: Double): Vector = this * (1 / d) + def divide (d: Double) = this / d + def unary_- = this * -1 def sum = elements.reduceLeft(_ + _) diff --git a/core/src/main/twirl/spark/deploy/master/app_details.scala.html b/core/src/main/twirl/spark/deploy/master/app_details.scala.html new file mode 100644 index 0000000000000000000000000000000000000000..301a7e212495d5fd9454be3f8620d41e34a58e24 --- /dev/null +++ b/core/src/main/twirl/spark/deploy/master/app_details.scala.html @@ -0,0 +1,40 @@ +@(app: spark.deploy.master.ApplicationInfo) + +@spark.common.html.layout(title = "Application Details") { + + <!-- Application Details --> + <div class="row"> + <div class="span12"> + <ul class="unstyled"> + <li><strong>ID:</strong> @app.id</li> + <li><strong>Description:</strong> @app.desc.name</li> + <li><strong>User:</strong> @app.desc.user</li> + <li><strong>Cores:</strong> + @app.desc.cores + (@app.coresGranted Granted + @if(app.desc.cores == Integer.MAX_VALUE) { + + } else { + , @app.coresLeft + } + ) + </li> + <li><strong>Memory per Slave:</strong> @app.desc.memoryPerSlave</li> + <li><strong>Submit Date:</strong> @app.submitDate</li> + <li><strong>State:</strong> @app.state</li> + </ul> + </div> + </div> + + <hr/> + + <!-- Executors --> + <div class="row"> + <div class="span12"> + <h3> Executor Summary </h3> + <br/> + @executors_table(app.executors.values.toList) + </div> + </div> + +} diff --git a/core/src/main/twirl/spark/deploy/master/app_row.scala.html b/core/src/main/twirl/spark/deploy/master/app_row.scala.html new file mode 100644 index 0000000000000000000000000000000000000000..feb306f35ccb58d6b0097dd72584589034364e90 --- /dev/null +++ b/core/src/main/twirl/spark/deploy/master/app_row.scala.html @@ -0,0 +1,20 @@ +@(app: spark.deploy.master.ApplicationInfo) + +@import spark.Utils +@import spark.deploy.WebUI.formatDate +@import spark.deploy.WebUI.formatDuration + +<tr> + <td> + <a href="app?appId=@(app.id)">@app.id</a> + </td> + <td>@app.desc.name</td> + <td> + @app.coresGranted + </td> + <td>@Utils.memoryMegabytesToString(app.desc.memoryPerSlave)</td> + <td>@formatDate(app.submitDate)</td> + <td>@app.desc.user</td> + <td>@app.state.toString()</td> + <td>@formatDuration(app.duration)</td> +</tr> diff --git a/core/src/main/twirl/spark/deploy/master/job_table.scala.html b/core/src/main/twirl/spark/deploy/master/app_table.scala.html similarity index 74% rename from core/src/main/twirl/spark/deploy/master/job_table.scala.html rename to core/src/main/twirl/spark/deploy/master/app_table.scala.html index d267d6e85e0b7acb73a120c81fe71692e59d7dce..f789cee0f16ea834cecea5a76634d310a5929d66 100644 --- a/core/src/main/twirl/spark/deploy/master/job_table.scala.html +++ b/core/src/main/twirl/spark/deploy/master/app_table.scala.html @@ -1,9 +1,9 @@ -@(jobs: Array[spark.deploy.master.JobInfo]) +@(apps: Array[spark.deploy.master.ApplicationInfo]) <table class="table table-bordered table-striped table-condensed sortable"> <thead> <tr> - <th>JobID</th> + <th>ID</th> <th>Description</th> <th>Cores</th> <th>Memory per Node</th> @@ -14,8 +14,8 @@ </tr> </thead> <tbody> - @for(j <- jobs) { - @job_row(j) + @for(j <- apps) { + @app_row(j) } </tbody> </table> diff --git a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html index 784d692fc2f91c2faa964ed76c41dfce91876e65..d2d80fad489920d1276742cc4097b37e1a639563 100644 --- a/core/src/main/twirl/spark/deploy/master/executor_row.scala.html +++ b/core/src/main/twirl/spark/deploy/master/executor_row.scala.html @@ -9,7 +9,7 @@ <td>@executor.memory</td> <td>@executor.state</td> <td> - <a href="@(executor.worker.webUiAddress)/log?jobId=@(executor.job.id)&executorId=@(executor.id)&logType=stdout">stdout</a> - <a href="@(executor.worker.webUiAddress)/log?jobId=@(executor.job.id)&executorId=@(executor.id)&logType=stderr">stderr</a> + <a href="@(executor.worker.webUiAddress)/log?appId=@(executor.application.id)&executorId=@(executor.id)&logType=stdout">stdout</a> + <a href="@(executor.worker.webUiAddress)/log?appId=@(executor.application.id)&executorId=@(executor.id)&logType=stderr">stderr</a> </td> -</tr> \ No newline at end of file +</tr> diff --git a/core/src/main/twirl/spark/deploy/master/index.scala.html b/core/src/main/twirl/spark/deploy/master/index.scala.html index 285645c38989504920f0f049475c4788f80c037c..ac51a39a5199d4cd96636f602ba0a9e7ce8b23c0 100644 --- a/core/src/main/twirl/spark/deploy/master/index.scala.html +++ b/core/src/main/twirl/spark/deploy/master/index.scala.html @@ -2,19 +2,19 @@ @import spark.deploy.master._ @import spark.Utils -@spark.common.html.layout(title = "Spark Master on " + state.uri) { - +@spark.common.html.layout(title = "Spark Master on " + state.host) { + <!-- Cluster Details --> <div class="row"> <div class="span12"> <ul class="unstyled"> - <li><strong>URL:</strong> spark://@(state.uri)</li> + <li><strong>URL:</strong> @(state.uri)</li> <li><strong>Workers:</strong> @state.workers.size </li> <li><strong>Cores:</strong> @{state.workers.map(_.cores).sum} Total, @{state.workers.map(_.coresUsed).sum} Used</li> <li><strong>Memory:</strong> @{Utils.memoryMegabytesToString(state.workers.map(_.memory).sum)} Total, @{Utils.memoryMegabytesToString(state.workers.map(_.memoryUsed).sum)} Used</li> - <li><strong>Jobs:</strong> @state.activeJobs.size Running, @state.completedJobs.size Completed </li> + <li><strong>Applications:</strong> @state.activeApps.size Running, @state.completedApps.size Completed </li> </ul> </div> </div> @@ -22,7 +22,7 @@ <!-- Worker Summary --> <div class="row"> <div class="span12"> - <h3> Cluster Summary </h3> + <h3> Workers </h3> <br/> @worker_table(state.workers.sortBy(_.id)) </div> @@ -30,23 +30,23 @@ <hr/> - <!-- Job Summary (Running) --> + <!-- App Summary (Running) --> <div class="row"> <div class="span12"> - <h3> Running Jobs </h3> + <h3> Running Applications </h3> <br/> - @job_table(state.activeJobs.sortBy(_.startTime).reverse) + @app_table(state.activeApps.sortBy(_.startTime).reverse) </div> </div> <hr/> - <!-- Job Summary (Completed) --> + <!-- App Summary (Completed) --> <div class="row"> <div class="span12"> - <h3> Completed Jobs </h3> + <h3> Completed Applications </h3> <br/> - @job_table(state.completedJobs.sortBy(_.endTime).reverse) + @app_table(state.completedApps.sortBy(_.endTime).reverse) </div> </div> diff --git a/core/src/main/twirl/spark/deploy/master/job_details.scala.html b/core/src/main/twirl/spark/deploy/master/job_details.scala.html deleted file mode 100644 index d02a51b214180c26df713d583b23d9fc87db7872..0000000000000000000000000000000000000000 --- a/core/src/main/twirl/spark/deploy/master/job_details.scala.html +++ /dev/null @@ -1,40 +0,0 @@ -@(job: spark.deploy.master.JobInfo) - -@spark.common.html.layout(title = "Job Details") { - - <!-- Job Details --> - <div class="row"> - <div class="span12"> - <ul class="unstyled"> - <li><strong>ID:</strong> @job.id</li> - <li><strong>Description:</strong> @job.desc.name</li> - <li><strong>User:</strong> @job.desc.user</li> - <li><strong>Cores:</strong> - @job.desc.cores - (@job.coresGranted Granted - @if(job.desc.cores == Integer.MAX_VALUE) { - - } else { - , @job.coresLeft - } - ) - </li> - <li><strong>Memory per Slave:</strong> @job.desc.memoryPerSlave</li> - <li><strong>Submit Date:</strong> @job.submitDate</li> - <li><strong>State:</strong> @job.state</li> - </ul> - </div> - </div> - - <hr/> - - <!-- Executors --> - <div class="row"> - <div class="span12"> - <h3> Executor Summary </h3> - <br/> - @executors_table(job.executors.values.toList) - </div> - </div> - -} diff --git a/core/src/main/twirl/spark/deploy/master/job_row.scala.html b/core/src/main/twirl/spark/deploy/master/job_row.scala.html deleted file mode 100644 index 7c466a6a2ce4d4ec722cd2435b7d084b35dc5d7f..0000000000000000000000000000000000000000 --- a/core/src/main/twirl/spark/deploy/master/job_row.scala.html +++ /dev/null @@ -1,20 +0,0 @@ -@(job: spark.deploy.master.JobInfo) - -@import spark.Utils -@import spark.deploy.WebUI.formatDate -@import spark.deploy.WebUI.formatDuration - -<tr> - <td> - <a href="job?jobId=@(job.id)">@job.id</a> - </td> - <td>@job.desc.name</td> - <td> - @job.coresGranted - </td> - <td>@Utils.memoryMegabytesToString(job.desc.memoryPerSlave)</td> - <td>@formatDate(job.submitDate)</td> - <td>@job.desc.user</td> - <td>@job.state.toString()</td> - <td>@formatDuration(job.duration)</td> -</tr> diff --git a/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html b/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html index ea9542461e5f736b41829f9bea8c07521dcc79f9..dad0a89080f2528848190b6803e7727d5bf932b8 100644 --- a/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html +++ b/core/src/main/twirl/spark/deploy/worker/executor_row.scala.html @@ -8,13 +8,13 @@ <td>@Utils.memoryMegabytesToString(executor.memory)</td> <td> <ul class="unstyled"> - <li><strong>ID:</strong> @executor.jobId</li> - <li><strong>Name:</strong> @executor.jobDesc.name</li> - <li><strong>User:</strong> @executor.jobDesc.user</li> + <li><strong>ID:</strong> @executor.appId</li> + <li><strong>Name:</strong> @executor.appDesc.name</li> + <li><strong>User:</strong> @executor.appDesc.user</li> </ul> </td> <td> - <a href="log?jobId=@(executor.jobId)&executorId=@(executor.execId)&logType=stdout">stdout</a> - <a href="log?jobId=@(executor.jobId)&executorId=@(executor.execId)&logType=stderr">stderr</a> + <a href="log?appId=@(executor.appId)&executorId=@(executor.execId)&logType=stdout">stdout</a> + <a href="log?appId=@(executor.appId)&executorId=@(executor.execId)&logType=stderr">stderr</a> </td> </tr> diff --git a/core/src/main/twirl/spark/deploy/worker/index.scala.html b/core/src/main/twirl/spark/deploy/worker/index.scala.html index 1d703dae58ccc3b5621f82e12c915648e38423f2..c39f769a7387f96113c3f088a88fd6d1ac19351e 100644 --- a/core/src/main/twirl/spark/deploy/worker/index.scala.html +++ b/core/src/main/twirl/spark/deploy/worker/index.scala.html @@ -1,8 +1,8 @@ @(worker: spark.deploy.WorkerState) @import spark.Utils -@spark.common.html.layout(title = "Spark Worker on " + worker.uri) { - +@spark.common.html.layout(title = "Spark Worker on " + worker.host) { + <!-- Worker Details --> <div class="row"> <div class="span12"> @@ -10,12 +10,12 @@ <li><strong>ID:</strong> @worker.workerId</li> <li><strong> Master URL:</strong> @worker.masterUrl - (WebUI at <a href="@worker.masterWebUiUrl">@worker.masterWebUiUrl</a>) </li> <li><strong>Cores:</strong> @worker.cores (@worker.coresUsed Used)</li> <li><strong>Memory:</strong> @{Utils.memoryMegabytesToString(worker.memory)} (@{Utils.memoryMegabytesToString(worker.memoryUsed)} Used)</li> </ul> + <p><a href="@worker.masterWebUiUrl">Back to Master</a></p> </div> </div> diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 0b74607fb85a6a5d0456b58744eba49bc1f98585..ca385972fb2ebe3add71881f18c3edc9761077d8 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -34,7 +34,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testCheckpointing(_.sample(false, 0.5, 0)) testCheckpointing(_.glom()) testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithSplitRDD(r, + testCheckpointing(r => new MapPartitionsWithIndexRDD(r, (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false )) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) @@ -43,14 +43,14 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { test("ParallelCollection") { val parCollection = sc.makeRDD(1 to 4, 2) - val numSplits = parCollection.splits.size + val numPartitions = parCollection.partitions.size parCollection.checkpoint() assert(parCollection.dependencies === Nil) val result = parCollection.collect() assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) - assert(parCollection.splits.length === numSplits) - assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList) + assert(parCollection.partitions.length === numPartitions) + assert(parCollection.partitions.toList === parCollection.checkpointData.get.getPartitions.toList) assert(parCollection.collect() === result) } @@ -59,13 +59,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { val blockManager = SparkEnv.get.blockManager blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) val blockRDD = new BlockRDD[String](sc, Array(blockId)) - val numSplits = blockRDD.splits.size + val numPartitions = blockRDD.partitions.size blockRDD.checkpoint() val result = blockRDD.collect() assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) - assert(blockRDD.splits.length === numSplits) - assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList) + assert(blockRDD.partitions.length === numPartitions) + assert(blockRDD.partitions.toList === blockRDD.checkpointData.get.getPartitions.toList) assert(blockRDD.collect() === result) } @@ -79,9 +79,9 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { test("UnionRDD") { def otherRDD = sc.makeRDD(1 to 10, 1) - // Test whether the size of UnionRDDSplits reduce in size after parent RDD is checkpointed. + // Test whether the size of UnionRDDPartitions reduce in size after parent RDD is checkpointed. // Current implementation of UnionRDD has transient reference to parent RDDs, - // so only the splits will reduce in serialized size, not the RDD. + // so only the partitions will reduce in serialized size, not the RDD. testCheckpointing(_.union(otherRDD), false, true) testParentCheckpointing(_.union(otherRDD), false, true) } @@ -91,21 +91,21 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testCheckpointing(new CartesianRDD(sc, _, otherRDD)) // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed - // Current implementation of CoalescedRDDSplit has transient reference to parent RDD, - // so only the RDD will reduce in serialized size, not the splits. + // Current implementation of CoalescedRDDPartition has transient reference to parent RDD, + // so only the RDD will reduce in serialized size, not the partitions. testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false) - // Test that the CartesianRDD updates parent splits (CartesianRDD.s1/s2) after - // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. + // Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after + // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions. // Note that this test is very specific to the current implementation of CartesianRDD. val ones = sc.makeRDD(1 to 100, 10).map(x => x) ones.checkpoint() // checkpoint that MappedRDD val cartesian = new CartesianRDD(sc, ones, ones) val splitBeforeCheckpoint = - serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit]) + serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) cartesian.count() // do the checkpointing val splitAfterCheckpoint = - serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit]) + serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) assert( (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) && (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2), @@ -114,27 +114,27 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } test("CoalescedRDD") { - testCheckpointing(new CoalescedRDD(_, 2)) + testCheckpointing(_.coalesce(2)) // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed - // Current implementation of CoalescedRDDSplit has transient reference to parent RDD, - // so only the RDD will reduce in serialized size, not the splits. - testParentCheckpointing(new CoalescedRDD(_, 2), true, false) + // Current implementation of CoalescedRDDPartition has transient reference to parent RDD, + // so only the RDD will reduce in serialized size, not the partitions. + testParentCheckpointing(_.coalesce(2), true, false) - // Test that the CoalescedRDDSplit updates parent splits (CoalescedRDDSplit.parents) after - // the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits. - // Note that this test is very specific to the current implementation of CoalescedRDDSplits + // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) after + // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions. + // Note that this test is very specific to the current implementation of CoalescedRDDPartitions val ones = sc.makeRDD(1 to 100, 10).map(x => x) ones.checkpoint() // checkpoint that MappedRDD val coalesced = new CoalescedRDD(ones, 2) val splitBeforeCheckpoint = - serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit]) + serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) coalesced.count() // do the checkpointing val splitAfterCheckpoint = - serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit]) + serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) assert( splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head, - "CoalescedRDDSplit.parents not updated after parent RDD checkpointed" + "CoalescedRDDPartition.parents not updated after parent RDD checkpointed" ) } @@ -156,30 +156,40 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) // Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed - // Current implementation of ZippedRDDSplit has transient references to parent RDDs, - // so only the RDD will reduce in serialized size, not the splits. + // Current implementation of ZippedRDDPartitions has transient references to parent RDDs, + // so only the RDD will reduce in serialized size, not the partitions. testParentCheckpointing( rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false) } + test("CheckpointRDD with zero partitions") { + val rdd = new BlockRDD[Int](sc, Array[String]()) + assert(rdd.partitions.size === 0) + assert(rdd.isCheckpointed === false) + rdd.checkpoint() + assert(rdd.count() === 0) + assert(rdd.isCheckpointed === true) + assert(rdd.partitions.size === 0) + } + /** * Test checkpointing of the final RDD generated by the given operation. By default, * this method tests whether the size of serialized RDD has reduced after checkpointing or not. - * It can also test whether the size of serialized RDD splits has reduced after checkpointing or - * not, but this is not done by default as usually the splits do not refer to any RDD and + * It can also test whether the size of serialized RDD partitions has reduced after checkpointing or + * not, but this is not done by default as usually the partitions do not refer to any RDD and * therefore never store the lineage. */ def testCheckpointing[U: ClassManifest]( op: (RDD[Int]) => RDD[U], testRDDSize: Boolean = true, - testRDDSplitSize: Boolean = false + testRDDPartitionSize: Boolean = false ) { // Generate the final RDD using given RDD operation val baseRDD = generateLongLineageRDD() val operatedRDD = op(baseRDD) val parentRDD = operatedRDD.dependencies.headOption.orNull val rddType = operatedRDD.getClass.getSimpleName - val numSplits = operatedRDD.splits.length + val numPartitions = operatedRDD.partitions.length // Find serialized sizes before and after the checkpoint val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) @@ -193,11 +203,11 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD) - // Test whether the splits have been changed to the new Hadoop splits - assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.getSplits.toList) + // Test whether the partitions have been changed to the new Hadoop partitions + assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList) - // Test whether the number of splits is same as before - assert(operatedRDD.splits.length === numSplits) + // Test whether the number of partitions is same as before + assert(operatedRDD.partitions.length === numPartitions) // Test whether the data in the checkpointed RDD is same as original assert(operatedRDD.collect() === result) @@ -215,18 +225,18 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { ) } - // Test whether serialized size of the splits has reduced. If the splits - // do not have any non-transient reference to another RDD or another RDD's splits, it + // Test whether serialized size of the partitions has reduced. If the partitions + // do not have any non-transient reference to another RDD or another RDD's partitions, it // does not refer to a lineage and therefore may not reduce in size after checkpointing. - // However, if the original splits before checkpointing do refer to a parent RDD, the splits + // However, if the original partitions before checkpointing do refer to a parent RDD, the partitions // must be forgotten after checkpointing (to remove all reference to parent RDDs) and - // replaced with the HadoopSplits of the checkpointed RDD. - if (testRDDSplitSize) { - logInfo("Size of " + rddType + " splits " + // replaced with the HadooPartitions of the checkpointed RDD. + if (testRDDPartitionSize) { + logInfo("Size of " + rddType + " partitions " + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]") assert( splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, - "Size of " + rddType + " splits did not reduce after checkpointing " + + "Size of " + rddType + " partitions did not reduce after checkpointing " + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" ) } @@ -235,13 +245,13 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { /** * Test whether checkpointing of the parent of the generated RDD also * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent - * RDDs splits. So even if the parent RDD is checkpointed and its splits changed, - * this RDD will remember the splits and therefore potentially the whole lineage. + * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, + * this RDD will remember the partitions and therefore potentially the whole lineage. */ def testParentCheckpointing[U: ClassManifest]( op: (RDD[Int]) => RDD[U], testRDDSize: Boolean, - testRDDSplitSize: Boolean + testRDDPartitionSize: Boolean ) { // Generate the final RDD using given RDD operation val baseRDD = generateLongLineageRDD() @@ -250,9 +260,9 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { val rddType = operatedRDD.getClass.getSimpleName val parentRDDType = parentRDD.getClass.getSimpleName - // Get the splits and dependencies of the parent in case they're lazily computed + // Get the partitions and dependencies of the parent in case they're lazily computed parentRDD.dependencies - parentRDD.splits + parentRDD.partitions // Find serialized sizes before and after the checkpoint val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD) @@ -275,16 +285,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { ) } - // Test whether serialized size of the splits has reduced because of its parent being - // checkpointed. If the splits do not have any non-transient reference to another RDD - // or another RDD's splits, it does not refer to a lineage and therefore may not reduce - // in size after checkpointing. However, if the splits do refer to the *splits* of a parent - // RDD, then these splits must update reference to the parent RDD splits as the parent RDD's - // splits must have changed after checkpointing. - if (testRDDSplitSize) { + // Test whether serialized size of the partitions has reduced because of its parent being + // checkpointed. If the partitions do not have any non-transient reference to another RDD + // or another RDD's partitions, it does not refer to a lineage and therefore may not reduce + // in size after checkpointing. However, if the partitions do refer to the *partitions* of a parent + // RDD, then these partitions must update reference to the parent RDD partitions as the parent RDD's + // partitions must have changed after checkpointing. + if (testRDDPartitionSize) { assert( splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint, - "Size of " + rddType + " splits did not reduce after checkpointing parent " + parentRDDType + + "Size of " + rddType + " partitions did not reduce after checkpointing parent " + parentRDDType + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" ) } @@ -321,12 +331,12 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } /** - * Get serialized sizes of the RDD and its splits, in order to test whether the size shrinks + * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint. */ def getSerializedSizes(rdd: RDD[_]): (Int, Int) = { (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length, - Utils.serialize(rdd.splits).length) + Utils.serialize(rdd.partitions).length) } /** @@ -347,7 +357,7 @@ object CheckpointSuite { def cogroup[K, V](first: RDD[(K, V)], second: RDD[(K, V)], part: Partitioner) = { //println("First = " + first + ", second = " + second) new CoGroupedRDD[K]( - Seq(first.asInstanceOf[RDD[(_, _)]], second.asInstanceOf[RDD[(_, _)]]), + Seq(first.asInstanceOf[RDD[(K, _)]], second.asInstanceOf[RDD[(K, _)]]), part ).asInstanceOf[RDD[(K, Seq[Seq[V]])]] } diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index 0e2585daa434cdad2c8deb14bd088072a69dfe31..caa4ba3a3705af5587099951c1914a03663b5ea8 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -217,6 +217,27 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter assert(grouped.collect.size === 1) } } + + test("recover from node failures with replication") { + import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity} + DistributedSuite.amMaster = true + // Using more than two nodes so we don't have a symmetric communication pattern and might + // cache a partially correct list of peers. + sc = new SparkContext("local-cluster[3,1,512]", "test") + for (i <- 1 to 3) { + val data = sc.parallelize(Seq(true, false, false, false), 4) + data.persist(StorageLevel.MEMORY_ONLY_2) + + assert(data.count === 4) + assert(data.map(markNodeIfIdentity).collect.size === 4) + assert(data.map(failOnMarkedIdentity).collect.size === 4) + + // Create a new replicated RDD to make sure that cached peer information doesn't cause + // problems. + val data2 = sc.parallelize(Seq(true, true), 2).persist(StorageLevel.MEMORY_ONLY_2) + assert(data2.count === 2) + } + } } object DistributedSuite { diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 934e4c2f6793bda90335d561b3792080a8490f38..9ffe7c5f992b6dfe82e77d596bcc81bea164be16 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -696,4 +696,28 @@ public class JavaAPISuite implements Serializable { JavaRDD<Integer> recovered = sc.checkpointFile(rdd.getCheckpointFile().get()); Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect()); } + + @Test + public void mapOnPairRDD() { + JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1,2,3,4)); + JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() { + @Override + public Tuple2<Integer, Integer> call(Integer i) throws Exception { + return new Tuple2<Integer, Integer>(i, i % 2); + } + }); + JavaPairRDD<Integer, Integer> rdd3 = rdd2.map( + new PairFunction<Tuple2<Integer, Integer>, Integer, Integer>() { + @Override + public Tuple2<Integer, Integer> call(Tuple2<Integer, Integer> in) throws Exception { + return new Tuple2<Integer, Integer>(in._2(), in._1()); + } + }); + Assert.assertEquals(Arrays.asList( + new Tuple2<Integer, Integer>(1, 1), + new Tuple2<Integer, Integer>(0, 2), + new Tuple2<Integer, Integer>(1, 3), + new Tuple2<Integer, Integer>(0, 4)), rdd3.collect()); + + } } diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala index af1107cd197b4f7c6221bd66bf9beee9ec7270f8..60db759c25f3b424280e869df5319b21c2c3e7e6 100644 --- a/core/src/test/scala/spark/PartitioningSuite.scala +++ b/core/src/test/scala/spark/PartitioningSuite.scala @@ -84,10 +84,10 @@ class PartitioningSuite extends FunSuite with LocalSparkContext { assert(grouped4.groupByKey(3).partitioner != grouped4.partitioner) assert(grouped4.groupByKey(4).partitioner === grouped4.partitioner) - assert(grouped2.join(grouped4).partitioner === grouped2.partitioner) - assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped2.partitioner) - assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped2.partitioner) - assert(grouped2.cogroup(grouped4).partitioner === grouped2.partitioner) + assert(grouped2.join(grouped4).partitioner === grouped4.partitioner) + assert(grouped2.leftOuterJoin(grouped4).partitioner === grouped4.partitioner) + assert(grouped2.rightOuterJoin(grouped4).partitioner === grouped4.partitioner) + assert(grouped2.cogroup(grouped4).partitioner === grouped4.partitioner) assert(grouped2.join(reduced2).partitioner === grouped2.partitioner) assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner) diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index fe7deb10d63b001ca5a3db3e9398dcb38c233126..9739ba869b3182d0fbdc17ae3e9aaa63d859882c 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -33,6 +33,11 @@ class RDDSuite extends FunSuite with LocalSparkContext { } assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7))) + val partitionSumsWithIndex = nums.mapPartitionsWithIndex { + case(split, iter) => Iterator((split, iter.reduceLeft(_ + _))) + } + assert(partitionSumsWithIndex.collect().toList === List((0, 3), (1, 7))) + intercept[UnsupportedOperationException] { nums.filter(_ > 5).reduce(_ + _) } @@ -97,12 +102,12 @@ class RDDSuite extends FunSuite with LocalSparkContext { test("caching with failures") { sc = new SparkContext("local", "test") - val onlySplit = new Split { override def index: Int = 0 } + val onlySplit = new Partition { override def index: Int = 0 } var shouldFail = true val rdd = new RDD[Int](sc, Nil) { - override def getSplits: Array[Split] = Array(onlySplit) + override def getPartitions: Array[Partition] = Array(onlySplit) override val getDependencies = List[Dependency[_]]() - override def compute(split: Split, context: TaskContext): Iterator[Int] = { + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { if (shouldFail) { throw new Exception("injected failure") } else { @@ -122,7 +127,7 @@ class RDDSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test") val data = sc.parallelize(1 to 10, 10) - val coalesced1 = new CoalescedRDD(data, 2) + val coalesced1 = data.coalesce(2) assert(coalesced1.collect().toList === (1 to 10).toList) assert(coalesced1.glom().collect().map(_.toList).toList === List(List(1, 2, 3, 4, 5), List(6, 7, 8, 9, 10))) @@ -133,19 +138,19 @@ class RDDSuite extends FunSuite with LocalSparkContext { assert(coalesced1.dependencies.head.asInstanceOf[NarrowDependency[_]].getParents(1).toList === List(5, 6, 7, 8, 9)) - val coalesced2 = new CoalescedRDD(data, 3) + val coalesced2 = data.coalesce(3) assert(coalesced2.collect().toList === (1 to 10).toList) assert(coalesced2.glom().collect().map(_.toList).toList === List(List(1, 2, 3), List(4, 5, 6), List(7, 8, 9, 10))) - val coalesced3 = new CoalescedRDD(data, 10) + val coalesced3 = data.coalesce(10) assert(coalesced3.collect().toList === (1 to 10).toList) assert(coalesced3.glom().collect().map(_.toList).toList === (1 to 10).map(x => List(x)).toList) // If we try to coalesce into more partitions than the original RDD, it should just // keep the original number of partitions. - val coalesced4 = new CoalescedRDD(data, 20) + val coalesced4 = data.coalesce(20) assert(coalesced4.collect().toList === (1 to 10).toList) assert(coalesced4.glom().collect().map(_.toList).toList === (1 to 10).map(x => List(x)).toList) @@ -168,7 +173,7 @@ class RDDSuite extends FunSuite with LocalSparkContext { val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) - assert(prunedRdd.splits.size === 1) + assert(prunedRdd.partitions.size === 1) val prunedData = prunedRdd.collect() assert(prunedData.size === 1) assert(prunedData(0) === 10) diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 3493b9511f6c2921d94ca1c7505ec9a436487dd5..8411291b2caa31e86f58bbf17f39fbf68020a669 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -1,6 +1,7 @@ package spark import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers @@ -98,6 +99,28 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { val sums = pairs.reduceByKey(_+_, 10).collect() assert(sums.toSet === Set((1, 7), (2, 1))) } + + test("reduceByKey with partitioner") { + sc = new SparkContext("local", "test") + val p = new Partitioner() { + def numPartitions = 2 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p) + val sums = pairs.reduceByKey(_+_) + assert(sums.collect().toSet === Set((1, 4), (0, 1))) + assert(sums.partitioner === Some(p)) + // count the dependencies to make sure there is only 1 ShuffledRDD + val deps = new HashSet[RDD[_]]() + def visit(r: RDD[_]) { + for (dep <- r.dependencies) { + deps += dep.rdd + visit(dep.rdd) + } + } + visit(sums) + assert(deps.size === 2) // ShuffledRDD, ParallelCollection + } test("join") { sc = new SparkContext("local", "test") @@ -199,7 +222,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { sc = new SparkContext("local", "test") val emptyDir = Files.createTempDir() val file = sc.textFile(emptyDir.getAbsolutePath) - assert(file.splits.size == 0) + assert(file.partitions.size == 0) assert(file.collect().toList === Nil) // Test that a shuffle on the file works, because this used to be a bug assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) @@ -211,6 +234,51 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(rdd.keys.collect().toList === List(1, 2)) assert(rdd.values.collect().toList === List("a", "b")) } + + test("default partitioner uses partition size") { + sc = new SparkContext("local", "test") + // specify 2000 partitions + val a = sc.makeRDD(Array(1, 2, 3, 4), 2000) + // do a map, which loses the partitioner + val b = a.map(a => (a, (a * 2).toString)) + // then a group by, and see we didn't revert to 2 partitions + val c = b.groupByKey() + assert(c.partitions.size === 2000) + } + + test("default partitioner uses largest partitioner") { + sc = new SparkContext("local", "test") + val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2) + val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000) + val c = a.join(b) + assert(c.partitions.size === 2000) + } + + test("subtract") { + sc = new SparkContext("local", "test") + val a = sc.parallelize(Array(1, 2, 3), 2) + val b = sc.parallelize(Array(2, 3, 4), 4) + val c = a.subtract(b) + assert(c.collect().toSet === Set(1)) + assert(c.partitions.size === a.partitions.size) + } + + test("subtract with narrow dependency") { + sc = new SparkContext("local", "test") + // use a deterministic partitioner + val p = new Partitioner() { + def numPartitions = 5 + def getPartition(key: Any) = key.asInstanceOf[Int] + } + // partitionBy so we have a narrow dependency + val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p) + println(sc.runJob(a, (i: Iterator[(Int, String)]) => i.toList).toList) + // more partitions/no partitioner so a shuffle dependency + val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4) + val c = a.subtract(b) + assert(c.collect().toSet === Set((1, "a"), (3, "c"))) + assert(c.partitioner.get === p) + } } object ShuffleSuite { diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala index edb8c839fcb708fedd6ddf90e66b906840792dc7..495f957e53f251a9597782530f120aa4e055b08f 100644 --- a/core/src/test/scala/spark/SortingSuite.scala +++ b/core/src/test/scala/spark/SortingSuite.scala @@ -19,7 +19,7 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 2) val sorted = pairs.sortByKey() - assert(sorted.splits.size === 2) + assert(sorted.partitions.size === 2) assert(sorted.collect() === pairArr.sortBy(_._1)) } @@ -29,17 +29,17 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 2) val sorted = pairs.sortByKey(true, 1) - assert(sorted.splits.size === 1) + assert(sorted.partitions.size === 1) assert(sorted.collect() === pairArr.sortBy(_._1)) } - test("large array with many splits") { + test("large array with many partitions") { sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } val pairs = sc.parallelize(pairArr, 2) val sorted = pairs.sortByKey(true, 20) - assert(sorted.splits.size === 20) + assert(sorted.partitions.size === 20) assert(sorted.collect() === pairArr.sortBy(_._1)) } @@ -59,7 +59,7 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1)) } - test("sort descending with many splits") { + test("sort descending with many partitions") { sc = new SparkContext("local", "test") val rand = new scala.util.Random() val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) } diff --git a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala b/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala similarity index 83% rename from core/src/test/scala/spark/ParallelCollectionSplitSuite.scala rename to core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala index 450c69bd58adb21fa246f2f6e195c9ff7e1a89b5..d27a2538e4489644a514ce677edf1ee4c29682c8 100644 --- a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/spark/rdd/ParallelCollectionSplitSuite.scala @@ -1,4 +1,4 @@ -package spark +package spark.rdd import scala.collection.immutable.NumericRange @@ -11,7 +11,7 @@ import org.scalacheck.Prop._ class ParallelCollectionSplitSuite extends FunSuite with Checkers { test("one element per slice") { val data = Array(1, 2, 3) - val slices = ParallelCollection.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices(0).mkString(",") === "1") assert(slices(1).mkString(",") === "2") @@ -20,14 +20,14 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { test("one slice") { val data = Array(1, 2, 3) - val slices = ParallelCollection.slice(data, 1) + val slices = ParallelCollectionRDD.slice(data, 1) assert(slices.size === 1) assert(slices(0).mkString(",") === "1,2,3") } test("equal slices") { val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9) - val slices = ParallelCollection.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices(0).mkString(",") === "1,2,3") assert(slices(1).mkString(",") === "4,5,6") @@ -36,7 +36,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { test("non-equal slices") { val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) - val slices = ParallelCollection.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices(0).mkString(",") === "1,2,3") assert(slices(1).mkString(",") === "4,5,6") @@ -45,7 +45,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { test("splitting exclusive range") { val data = 0 until 100 - val slices = ParallelCollection.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices(0).mkString(",") === (0 to 32).mkString(",")) assert(slices(1).mkString(",") === (33 to 65).mkString(",")) @@ -54,7 +54,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { test("splitting inclusive range") { val data = 0 to 100 - val slices = ParallelCollection.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices(0).mkString(",") === (0 to 32).mkString(",")) assert(slices(1).mkString(",") === (33 to 66).mkString(",")) @@ -63,24 +63,24 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { test("empty data") { val data = new Array[Int](0) - val slices = ParallelCollection.slice(data, 5) + val slices = ParallelCollectionRDD.slice(data, 5) assert(slices.size === 5) for (slice <- slices) assert(slice.size === 0) } test("zero slices") { val data = Array(1, 2, 3) - intercept[IllegalArgumentException] { ParallelCollection.slice(data, 0) } + intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, 0) } } test("negative number of slices") { val data = Array(1, 2, 3) - intercept[IllegalArgumentException] { ParallelCollection.slice(data, -5) } + intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, -5) } } test("exclusive ranges sliced into ranges") { val data = 1 until 100 - val slices = ParallelCollection.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices.map(_.size).reduceLeft(_+_) === 99) assert(slices.forall(_.isInstanceOf[Range])) @@ -88,7 +88,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { test("inclusive ranges sliced into ranges") { val data = 1 to 100 - val slices = ParallelCollection.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices.map(_.size).reduceLeft(_+_) === 100) assert(slices.forall(_.isInstanceOf[Range])) @@ -97,7 +97,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { test("large ranges don't overflow") { val N = 100 * 1000 * 1000 val data = 0 until N - val slices = ParallelCollection.slice(data, 40) + val slices = ParallelCollectionRDD.slice(data, 40) assert(slices.size === 40) for (i <- 0 until 40) { assert(slices(i).isInstanceOf[Range]) @@ -117,7 +117,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { (tuple: (List[Int], Int)) => val d = tuple._1 val n = tuple._2 - val slices = ParallelCollection.slice(d, n) + val slices = ParallelCollectionRDD.slice(d, n) ("n slices" |: slices.size == n) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && ("equal sizes" |: slices.map(_.size).forall(x => x==d.size/n || x==d.size/n+1)) @@ -134,7 +134,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { } yield (a until b by step, n) val prop = forAll(gen) { case (d: Range, n: Int) => - val slices = ParallelCollection.slice(d, n) + val slices = ParallelCollectionRDD.slice(d, n) ("n slices" |: slices.size == n) && ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && @@ -152,7 +152,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { } yield (a to b by step, n) val prop = forAll(gen) { case (d: Range, n: Int) => - val slices = ParallelCollection.slice(d, n) + val slices = ParallelCollectionRDD.slice(d, n) ("n slices" |: slices.size == n) && ("all ranges" |: slices.forall(_.isInstanceOf[Range])) && ("concat to d" |: Seq.concat(slices: _*).mkString(",") == d.mkString(",")) && @@ -163,7 +163,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { test("exclusive ranges of longs") { val data = 1L until 100L - val slices = ParallelCollection.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices.map(_.size).reduceLeft(_+_) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) @@ -171,7 +171,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { test("inclusive ranges of longs") { val data = 1L to 100L - val slices = ParallelCollection.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices.map(_.size).reduceLeft(_+_) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) @@ -179,7 +179,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { test("exclusive ranges of doubles") { val data = 1.0 until 100.0 by 1.0 - val slices = ParallelCollection.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices.map(_.size).reduceLeft(_+_) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) @@ -187,7 +187,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { test("inclusive ranges of doubles") { val data = 1.0 to 100.0 by 1.0 - val slices = ParallelCollection.slice(data, 3) + val slices = ParallelCollectionRDD.slice(data, 3) assert(slices.size === 3) assert(slices.map(_.size).reduceLeft(_+_) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala index edc5a7dfbaa86f3f0a3a89af96090d9cb2afdeab..07cccc7ce087418ce90c4b1b9b5a9cb9256b4bf5 100644 --- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala @@ -14,7 +14,7 @@ import spark.MapOutputTracker import spark.RDD import spark.SparkContext import spark.SparkException -import spark.Split +import spark.Partition import spark.TaskContext import spark.TaskEndReason @@ -111,18 +111,18 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter { * so we can test that DAGScheduler does not try to execute RDDs locally. */ private def makeRdd( - numSplits: Int, + numPartitions: Int, dependencies: List[Dependency[_]], locations: Seq[Seq[String]] = Nil ): MyRDD = { - val maxSplit = numSplits - 1 + val maxPartition = numPartitions - 1 return new MyRDD(sc, dependencies) { - override def compute(split: Split, context: TaskContext): Iterator[(Int, Int)] = + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = throw new RuntimeException("should not be reached") - override def getSplits() = (0 to maxSplit).map(i => new Split { + override def getPartitions = (0 to maxPartition).map(i => new Partition { override def index = i }).toArray - override def getPreferredLocations(split: Split): Seq[String] = + override def getPreferredLocations(split: Partition): Seq[String] = if (locations.isDefinedAt(split.index)) locations(split.index) else @@ -196,9 +196,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter { test("local job") { val rdd = new MyRDD(sc, Nil) { - override def compute(split: Split, context: TaskContext) = Array(42 -> 0).iterator - override def getSplits() = Array(new Split { override def index = 0 }) - override def getPreferredLocations(split: Split) = Nil + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + Array(42 -> 0).iterator + override def getPartitions = Array( new Partition { override def index = 0 } ) + override def getPreferredLocations(split: Partition) = Nil override def toString = "DAGSchedulerSuite Local RDD" } runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener)) @@ -397,4 +398,4 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter { private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) -} \ No newline at end of file +} diff --git a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala index a5db7103f5ce8cd28a7ce989f9f5fc17666efc17..647bcaf860a37c83354493a805d9aa06aa7f9dba 100644 --- a/core/src/test/scala/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/spark/scheduler/TaskContextSuite.scala @@ -5,7 +5,7 @@ import org.scalatest.BeforeAndAfter import spark.TaskContext import spark.RDD import spark.SparkContext -import spark.Split +import spark.Partition import spark.LocalSparkContext class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { @@ -14,8 +14,8 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte var completed = false sc = new SparkContext("local", "test") val rdd = new RDD[String](sc, List()) { - override def getSplits = Array[Split](StubSplit(0)) - override def compute(split: Split, context: TaskContext) = { + override def getPartitions = Array[Partition](StubPartition(0)) + override def compute(split: Partition, context: TaskContext) = { context.addOnCompleteCallback(() => completed = true) sys.error("failed") } @@ -28,5 +28,5 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte assert(completed === true) } - case class StubSplit(val index: Int) extends Split -} \ No newline at end of file + case class StubPartition(val index: Int) extends Partition +} diff --git a/docs/_config.yml b/docs/_config.yml index 2bd2eecc863e40eba17372829dcb6e0ce5d108e9..09617e4a1efb6b3486970b7fc7715f1fb0993a16 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -7,3 +7,4 @@ SPARK_VERSION: 0.7.0-SNAPSHOT SPARK_VERSION_SHORT: 0.7.0 SCALA_VERSION: 2.9.2 MESOS_VERSION: 0.9.0-incubating +SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net diff --git a/docs/configuration.md b/docs/configuration.md index a7054b4321bcf499ecc22e35cfeaa42437ad8157..04eb6daaa5d016a6ecdaa51575011882bb328bb1 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -183,7 +183,7 @@ Apart from these, the following properties are also available, and may be useful </tr> <tr> <td>spark.broadcast.factory</td> - <td>spark.broadcast. HttpBroadcastFactory</td> + <td>spark.broadcast.HttpBroadcastFactory</td> <td> Which broadcast implementation to use. </td> @@ -197,6 +197,14 @@ Apart from these, the following properties are also available, and may be useful poor data locality, but the default generally works well. </td> </tr> +<tr> + <td>spark.worker.timeout</td> + <td>60</td> + <td> + Number of seconds after which the standalone deploy master considers a worker lost if it + receives no heartbeats. + </td> +</tr> <tr> <td>spark.akka.frameSize</td> <td>10</td> @@ -218,7 +226,7 @@ Apart from these, the following properties are also available, and may be useful <td>spark.akka.timeout</td> <td>20</td> <td> - Communication timeout between Spark nodes. + Communication timeout between Spark nodes, in seconds. </td> </tr> <tr> @@ -236,10 +244,10 @@ Apart from these, the following properties are also available, and may be useful </td> </tr> <tr> - <td>spark.cleaner.delay</td> + <td>spark.cleaner.ttl</td> <td>(disable)</td> <td> - Duration (minutes) of how long Spark will remember any metadata (stages generated, tasks generated, etc.). + Duration (seconds) of how long Spark will remember any metadata (stages generated, tasks generated, etc.). Periodic cleanups will ensure that metadata older than this duration will be forgetten. This is useful for running Spark for many hours / days (for example, running 24/7 in case of Spark Streaming applications). Note that any RDD that persists in memory for more than this duration will be cleared as well. diff --git a/docs/contributing-to-spark.md b/docs/contributing-to-spark.md index c6e01c62d808041400913a305f9346c4dd8968d3..50feeb2d6c42afa14e7092a7902d75e5362d570f 100644 --- a/docs/contributing-to-spark.md +++ b/docs/contributing-to-spark.md @@ -15,7 +15,7 @@ The Spark team welcomes contributions in the form of GitHub pull requests. Here But first, make sure that you have [configured a spark-env.sh](configuration.html) with at least `SCALA_HOME`, as some of the tests try to spawn subprocesses using this. - Add new unit tests for your code. We use [ScalaTest](http://www.scalatest.org/) for testing. Just add a new Suite in `core/src/test`, or methods to an existing Suite. -- If you'd like to report a bug but don't have time to fix it, you can still post it to our [issues page](https://github.com/mesos/spark/issues), or email the [mailing list](http://www.spark-project.org/mailing-lists.html). +- If you'd like to report a bug but don't have time to fix it, you can still post it to our [issue tracker]({{site.SPARK_ISSUE_TRACKER_URL}}), or email the [mailing list](http://www.spark-project.org/mailing-lists.html). # Licensing of Contributions diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index 4e84d23edf56b0a6b6fe3291744c269be0085b28..2012241a6a77bb1a0114b72242801c00a44aea14 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -87,7 +87,7 @@ By default, the `pyspark` shell creates SparkContext that runs jobs locally. To connect to a non-local cluster, set the `MASTER` environment variable. For example, to use the `pyspark` shell with a [standalone Spark cluster](spark-standalone.html): -{% highlight shell %} +{% highlight bash %} $ MASTER=spark://IP:PORT ./pyspark {% endhighlight %} diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md index 301b330a79e933caedb0da365948f13bcb8e2d3c..b98718a5532e0d49e19437be39bc7482f7d6d328 100644 --- a/docs/scala-programming-guide.md +++ b/docs/scala-programming-guide.md @@ -203,7 +203,7 @@ A complete list of transformations is available in the [RDD API doc](api/core/in <tr><th>Action</th><th>Meaning</th></tr> <tr> <td> <b>reduce</b>(<i>func</i>) </td> - <td> Aggregate the elements of the dataset using a function <i>func</i> (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel. </td> + <td> Aggregate the elements of the dataset using a function <i>func</i> (which takes two arguments and returns one). The function should be commutative and associative so that it can be computed correctly in parallel. </td> </tr> <tr> <td> <b>collect</b>() </td> diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index bf296221b844776b5415af2dc6b6cb5b531e3e36..3986c0c79dceaecb0da46adbf07241b64229f650 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -115,6 +115,14 @@ You can optionally configure the cluster further by setting environment variable <td><code>SPARK_WORKER_WEBUI_PORT</code></td> <td>Port for the worker web UI (default: 8081)</td> </tr> + <tr> + <td><code>SPARK_DAEMON_MEMORY</code></td> + <td>Memory to allocate to the Spark master and worker daemons themselves (default: 512m)</td> + </tr> + <tr> + <td><code>SPARK_DAEMON_JAVA_OPTS</code></td> + <td>JVM options for the Spark master and worker daemons themselves (default: none)</td> + </tr> </table> diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md new file mode 100644 index 0000000000000000000000000000000000000000..5476c00d020cb2cfece90853d7e010d08055c00d --- /dev/null +++ b/docs/streaming-custom-receivers.md @@ -0,0 +1,101 @@ +--- +layout: global +title: Tutorial - Spark Streaming, Plugging in a custom receiver. +--- + +A "Spark Streaming" receiver can be a simple network stream, streams of messages from a message queue, files etc. A receiver can also assume roles more than just receiving data like filtering, preprocessing, to name a few of the possibilities. The api to plug-in any user defined custom receiver is thus provided to encourage development of receivers which may be well suited to ones specific need. + +This guide shows the programming model and features by walking through a simple sample receiver and corresponding Spark Streaming application. + + +## A quick and naive walk-through + +### Write a simple receiver + +This starts with implementing [Actor](#References) + +Following is a simple socket text-stream receiver, which is appearently overly simplified using Akka's socket.io api. + +{% highlight scala %} + + class SocketTextStreamReceiver (host:String, + port:Int, + bytesToString: ByteString => String) extends Actor with Receiver { + + override def preStart = IOManager(context.system).connect(host, port) + + def receive = { + case IO.Read(socket, bytes) => pushBlock(bytesToString(bytes)) + } + + } + + +{% endhighlight %} + +All we did here is mixed in trait Receiver and called pushBlock api method to push our blocks of data. Please refer to scala-docs of Receiver for more details. + +### A sample spark application + +* First create a Spark streaming context with master url and batchduration. + +{% highlight scala %} + + val ssc = new StreamingContext(master, "WordCountCustomStreamSource", + Seconds(batchDuration)) + +{% endhighlight %} + +* Plug-in the actor configuration into the spark streaming context and create a DStream. + +{% highlight scala %} + + val lines = ssc.actorStream[String](Props(new SocketTextStreamReceiver( + "localhost",8445, z => z.utf8String)),"SocketReceiver") + +{% endhighlight %} + +* Process it. + +{% highlight scala %} + + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + + wordCounts.print() + ssc.start() + + +{% endhighlight %} + +* After processing it, stream can be tested using the netcat utility. + + $ nc -l localhost 8445 + hello world + hello hello + + +## Multiple homogeneous/heterogeneous receivers. + +A DStream union operation is provided for taking union on multiple input streams. + +{% highlight scala %} + + val lines = ssc.actorStream[String](Props(new SocketTextStreamReceiver( + "localhost",8445, z => z.utf8String)),"SocketReceiver") + + // Another socket stream receiver + val lines2 = ssc.actorStream[String](Props(new SocketTextStreamReceiver( + "localhost",8446, z => z.utf8String)),"SocketReceiver") + + val union = lines.union(lines2) + +{% endhighlight %} + +Above stream can be easily process as described earlier. + +_A more comprehensive example is provided in the spark streaming examples_ + +## References + +1.[Akka Actor documentation](http://doc.akka.io/docs/akka/2.0.5/scala/actors.html) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index b6da7af654ee6de2e5ae5733faf1c69ad8a8092b..0e618a06c796fcb821e4f4aa4d46e274e5f48494 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -34,34 +34,34 @@ The StreamingContext is used to creating InputDStreams from input sources: {% highlight scala %} // Assuming ssc is the StreamingContext -ssc.networkStream(hostname, port) // Creates a stream that uses a TCP socket to read data from hostname:port -ssc.textFileStream(directory) // Creates a stream by monitoring and processing new files in a HDFS directory +ssc.textFileStream(directory) // Creates a stream by monitoring and processing new files in a HDFS directory +ssc.socketStream(hostname, port) // Creates a stream that uses a TCP socket to read data from hostname:port {% endhighlight %} -A complete list of input sources is available in the [StreamingContext API documentation](api/streaming/index.html#spark.streaming.StreamingContext). Data received from these sources can be processed using DStream operations, which are explained next. +We also provide a input streams for Kafka, Flume, Akka actor, etc. For a complete list of input streams, take a look at the [StreamingContext API documentation](api/streaming/index.html#spark.streaming.StreamingContext). # DStream Operations -Once an input DStream has been created, you can transform it using _DStream operators_. Most of these operators return new DStreams which you can further transform. Eventually, you'll need to call an _output operator_, which forces evaluation of the DStream by writing data out to an external source. +Data received from the input streams can be processed using _DStream operations_. There are two kinds of operations - _transformations_ and _output operations_. Similar to RDD transformations, DStream transformations operate on one or more DStreams to create new DStreams with transformed data. After applying a sequence of transformations to the input streams, you'll need to call the output operations, which writies data out to an external source. ## Transformations DStreams support many of the transformations available on normal Spark RDD's: <table class="table"> -<tr><th style="width:25%">Transformation</th><th>Meaning</th></tr> +<tr><th style="width:30%">Transformation</th><th>Meaning</th></tr> <tr> <td> <b>map</b>(<i>func</i>) </td> - <td> Returns a new DStream formed by passing each element of the source through a function <i>func</i>. </td> + <td> Returns a new DStream formed by passing each element of the source DStream through a function <i>func</i>. </td> </tr> <tr> <td> <b>filter</b>(<i>func</i>) </td> - <td> Returns a new stream formed by selecting those elements of the source on which <i>func</i> returns true. </td> + <td> Returns a new DStream formed by selecting those elements of the source DStream on which <i>func</i> returns true. </td> </tr> <tr> <td> <b>flatMap</b>(<i>func</i>) </td> - <td> Similar to map, but each input item can be mapped to 0 or more output items (so <i>func</i> should return a Seq rather than a single item). </td> + <td> Similar to map, but each input item can be mapped to 0 or more output items (so <i>func</i> should return a <code>Seq</code> rather than a single item). </td> </tr> <tr> <td> <b>mapPartitions</b>(<i>func</i>) </td> @@ -70,73 +70,92 @@ DStreams support many of the transformations available on normal Spark RDD's: </tr> <tr> <td> <b>union</b>(<i>otherStream</i>) </td> - <td> Return a new stream that contains the union of the elements in the source stream and the argument. </td> + <td> Return a new DStream that contains the union of the elements in the source DStream and the argument DStream. </td> +</tr> +<tr> + <td> <b>count</b>() </td> + <td> Returns a new DStream of single-element RDDs by counting the number of elements in each RDD of the source DStream. </td> +</tr> +<tr> + <td> <b>reduce</b>(<i>func</i>) </td> + <td> Returns a new DStream of single-element RDDs by aggregating the elements in each RDD of the source DStream using a function <i>func</i> (which takes two arguments and returns one). The function should be associative so that it can be computed in parallel. </td> +</tr> +<tr> + <td> <b>countByValue</b>() </td> + <td> When called on a DStream of elements of type K, returns a new DStream of (K, Long) pairs where the value of each key is its frequency in each RDD of the source DStream. </td> </tr> <tr> <td> <b>groupByKey</b>([<i>numTasks</i>]) </td> - <td> When called on a stream of (K, V) pairs, returns a stream of (K, Seq[V]) pairs. <br /> -<b>Note:</b> By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks. + <td> When called on a DStream of (K, V) pairs, returns a new DStream of (K, Seq[V]) pairs by grouping together all the values of each key in the RDDs of the source DStream. <br /> + <b>Note:</b> By default, this uses Spark's default number of parallel tasks (2 for local machine, 8 for a cluser) to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks. </td> </tr> <tr> <td> <b>reduceByKey</b>(<i>func</i>, [<i>numTasks</i>]) </td> - <td> When called on a stream of (K, V) pairs, returns a stream of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td> + <td> When called on a DStream of (K, V) pairs, returns a new DStream of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td> </tr> <tr> <td> <b>join</b>(<i>otherStream</i>, [<i>numTasks</i>]) </td> - <td> When called on streams of type (K, V) and (K, W), returns a stream of (K, (V, W)) pairs with all pairs of elements for each key. </td> + <td> When called on two DStreams of (K, V) and (K, W) pairs, returns a new DStream of (K, (V, W)) pairs with all pairs of elements for each key. </td> </tr> <tr> <td> <b>cogroup</b>(<i>otherStream</i>, [<i>numTasks</i>]) </td> - <td> When called on DStream of type (K, V) and (K, W), returns a DStream of (K, Seq[V], Seq[W]) tuples.</td> -</tr> -<tr> - <td> <b>reduce</b>(<i>func</i>) </td> - <td> Returns a new DStream of single-element RDDs by aggregating the elements of the stream using a function func (which takes two arguments and returns one). The function should be associative so that it can be computed correctly in parallel. </td> + <td> When called on DStream of (K, V) and (K, W) pairs, returns a new DStream of (K, Seq[V], Seq[W]) tuples.</td> </tr> <tr> <td> <b>transform</b>(<i>func</i>) </td> <td> Returns a new DStream by applying func (a RDD-to-RDD function) to every RDD of the stream. This can be used to do arbitrary RDD operations on the DStream. </td> </tr> +<tr> + <td> <b>updateStateByKey</b>(<i>func</i>) </td> + <td> Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values of each key. This can be used to track session state by using the session-id as the key and updating the session state as new data is received.</td> +</tr> + </table> -Spark Streaming features windowed computations, which allow you to report statistics over a sliding window of data. All window functions take a <i>windowDuration</i>, which represents the width of the window and a <i>slideTime</i>, which represents the frequency during which the window is calculated. +Spark Streaming features windowed computations, which allow you to apply transformations over a sliding window of data. All window functions take a <i>windowDuration</i>, which represents the width of the window and a <i>slideTime</i>, which represents the frequency during which the window is calculated. <table class="table"> -<tr><th style="width:25%">Transformation</th><th>Meaning</th></tr> +<tr><th style="width:30%">Transformation</th><th>Meaning</th></tr> <tr> - <td> <b>window</b>(<i>windowDuration</i>, </i>slideTime</i>) </td> - <td> Return a new stream which is computed based on windowed batches of the source stream. <i>windowDuration</i> is the width of the window and <i>slideTime</i> is the frequency during which the window is calculated. Both times must be multiples of the batch interval. + <td> <b>window</b>(<i>windowDuration</i>, </i>slideDuration</i>) </td> + <td> Return a new DStream which is computed based on windowed batches of the source DStream. <i>windowDuration</i> is the width of the window and <i>slideTime</i> is the frequency during which the window is calculated. Both times must be multiples of the batch interval. </td> </tr> <tr> - <td> <b>countByWindow</b>(<i>windowDuration</i>, </i>slideTime</i>) </td> + <td> <b>countByWindow</b>(<i>windowDuration</i>, </i>slideDuration</i>) </td> <td> Return a sliding count of elements in the stream. <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>. </td> </tr> <tr> - <td> <b>reduceByWindow</b>(<i>func</i>, <i>windowDuration</i>, </i>slideDuration</i>) </td> + <td> <b>reduceByWindow</b>(<i>func</i>, <i>windowDuration</i>, <i>slideDuration</i>) </td> <td> Return a new single-element stream, created by aggregating elements in the stream over a sliding interval using <i>func</i>. The function should be associative so that it can be computed correctly in parallel. <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>. </td> </tr> <tr> - <td> <b>groupByKeyAndWindow</b>(windowDuration, slideDuration, [<i>numTasks</i>]) + <td> <b>groupByKeyAndWindow</b>(<i>windowDuration</i>, <i>slideDuration</i>, [<i>numTasks</i>]) </td> - <td> When called on a stream of (K, V) pairs, returns a stream of (K, Seq[V]) pairs over a sliding window. <br /> -<b>Note:</b> By default, this uses only 8 parallel tasks to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks. <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>. -</td> + <td> When called on a DStream of (K, V) pairs, returns a new DStream of (K, Seq[V]) pairs by grouping together values of each key over batches in a sliding window. <br /> +<b>Note:</b> By default, this uses Spark's default number of parallel tasks (2 for local machine, 8 for a cluser) to do the grouping. You can pass an optional <code>numTasks</code> argument to set a different number of tasks.</td> </tr> <tr> - <td> <b>reduceByKeyAndWindow</b>(<i>func</i>, [<i>numTasks</i>]) </td> - <td> When called on a stream of (K, V) pairs, returns a stream of (K, V) pairs where the values for each key are aggregated using the given reduce function over batches within a sliding window. Like in <code>groupByKeyAndWindow</code>, the number of reduce tasks is configurable through an optional second argument. + <td> <b>reduceByKeyAndWindow</b>(<i>func</i>, <i>windowDuration</i>, <i>slideDuration</i>, [<i>numTasks</i>]) </td> + <td> When called on a DStream of (K, V) pairs, returns a new DStream of (K, V) pairs where the values for each key are aggregated using the given reduce function <i>func</i> over batches in a sliding window. Like in <code>groupByKeyAndWindow</code>, the number of reduce tasks is configurable through an optional second argument. <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>. </td> </tr> <tr> - <td> <b>countByKeyAndWindow</b>([<i>numTasks</i>]) </td> - <td> When called on a stream of (K, V) pairs, returns a stream of (K, Int) pairs where the values for each key are the count within a sliding window. Like in <code>countByKeyAndWindow</code>, the number of reduce tasks is configurable through an optional second argument. + <td> <b>reduceByKeyAndWindow</b>(<i>func</i>, <i>invFunc</i>, <i>windowDuration</i>, <i>slideDuration</i>, [<i>numTasks</i>]) </td> + <td> A more efficient version of the above <code>reduceByKeyAndWindow()</code> where the reduce value of each window is calculated + incrementally using the reduce values of the previous window. This is done by reducing the new data that enter the sliding window, and "inverse reducing" the old data that leave the window. An example would be that of "adding" and "subtracting" counts of keys as the window slides. However, it is applicable to only "invertible reduce functions", that is, those reduce functions which have a corresponding "inverse reduce" function (taken as parameter <i>invFunc</i>. Like in <code>groupByKeyAndWindow</code>, the number of reduce tasks is configurable through an optional second argument. <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>. -</td> +</td> +</tr> +<tr> + <td> <b>countByValueAndWindow</b>(<i>windowDuration</i>, <i>slideDuration</i>, [<i>numTasks</i>]) </td> + <td> When called on a DStream of (K, V) pairs, returns a new DStream of (K, Long) pairs where the value of each key is its frequency within a sliding window. Like in <code>groupByKeyAndWindow</code>, the number of reduce tasks is configurable through an optional second argument. + <i>windowDuration</i> and <i>slideDuration</i> are exactly as defined in <code>window()</code>. +</td> </tr> </table> @@ -147,7 +166,7 @@ A complete list of DStream operations is available in the API documentation of [ When an output operator is called, it triggers the computation of a stream. Currently the following output operators are defined: <table class="table"> -<tr><th style="width:25%">Operator</th><th>Meaning</th></tr> +<tr><th style="width:30%">Operator</th><th>Meaning</th></tr> <tr> <td> <b>foreach</b>(<i>func</i>) </td> <td> The fundamental output operator. Applies a function, <i>func</i>, to each RDD generated from the stream. This function should have side effects, such as printing output, saving the RDD to external files, or writing it over the network to an external system. </td> @@ -176,11 +195,6 @@ When an output operator is called, it triggers the computation of a stream. Curr </table> -## DStream Persistence -Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream would automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple DStream operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. Hence, DStreams generated by window-based operations are automatically persisted in memory, without the developer calling `persist()`. - -Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. More information on different persistence levels can be found in [Spark Programming Guide](scala-programming-guide.html#rdd-persistence). - # Starting the Streaming computation All the above DStream operations are completely lazy, that is, the operations will start executing only after the context is started by using {% highlight scala %} @@ -192,8 +206,8 @@ Conversely, the computation can be stopped by using ssc.stop() {% endhighlight %} -# Example - NetworkWordCount.scala -A good example to start off is the spark.streaming.examples.NetworkWordCount. This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in <Spark repo>/streaming/src/main/scala/spark/streaming/examples/WordCountNetwork.scala. +# Example +A simple example to start off is the [NetworkWordCount](https://github.com/mesos/spark/tree/master/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala). This example counts the words received from a network server every second. Given below is the relevant sections of the source code. You can find the full source code in `<Spark repo>/streaming/src/main/scala/spark/streaming/examples/NetworkWordCount.scala` . {% highlight scala %} import spark.streaming.{Seconds, StreamingContext} @@ -202,7 +216,7 @@ import spark.streaming.StreamingContext._ // Create the context and set up a network input stream to receive from a host:port val ssc = new StreamingContext(args(0), "NetworkWordCount", Seconds(1)) -val lines = ssc.networkTextStream(args(1), args(2).toInt) +val lines = ssc.socketTextStream(args(1), args(2).toInt) // Split the lines into words, count them, and print some of the counts on the master val words = lines.flatMap(_.split(" ")) @@ -213,6 +227,8 @@ wordCounts.print() ssc.start() {% endhighlight %} +The `socketTextStream` returns a DStream of lines received from a TCP socket-based source. The `lines` DStream is _transformed_ into a DStream using the `flatMap` operation, where each line is split into words. This `words` DStream is then mapped to a DStream of `(word, 1)` pairs, which is finally reduced to get the word counts. `wordCounts.print()` will print 10 of the counts generated every second. + To run this example on your local machine, you need to first run a Netcat server by using {% highlight bash %} @@ -260,6 +276,31 @@ Time: 1357008430000 ms </td> </table> +You can find more examples in `<Spark repo>/streaming/src/main/scala/spark/streaming/examples/`. They can be run in the similar manner using `./run spark.streaming.examples....` . Executing without any parameter would give the required parameter list. Further explanation to run them can be found in comments in the files. + +# DStream Persistence +Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream would automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple operations on the same data). For window-based operations like `reduceByWindow` and `reduceByKeyAndWindow` and state-based operations like `updateStateByKey`, this is implicitly true. Hence, DStreams generated by window-based operations are automatically persisted in memory, without the developer calling `persist()`. + +For input streams that receive data from the network (that is, subclasses of NetworkInputDStream like FlumeInputDStream and KafkaInputDStream), the default persistence level is set to replicate the data to two nodes for fault-tolerance. + +Note that, unlike RDDs, the default persistence level of DStreams keeps the data serialized in memory. This is further discussed in the [Performance Tuning](#memory-tuning) section. More information on different persistence levels can be found in [Spark Programming Guide](scala-programming-guide.html#rdd-persistence). + +# RDD Checkpointing within DStreams +DStreams created by stateful operations like `updateStateByKey` require the RDDs in the DStream to be periodically saved to HDFS files for checkpointing. This is because, unless checkpointed, the lineage of operations of the state RDDs can increase indefinitely (since each RDD in the DStream depends on the previous RDD). This leads to two problems - (i) the size of Spark tasks increase proportionally with the RDD lineage leading higher task launch times, (ii) no limit on the amount of recomputation required on failure. Checkpointing RDDs at some interval by writing them to HDFS allows the lineage to be truncated. Note that checkpointing also incurs the cost of saving to HDFS which may cause the corresponding batch to take longer to process. Hence, the interval of checkpointing needs to be set carefully. At small batch sizes (say 1 second), checkpointing every batch may significantly reduce operation throughput. Conversely, checkpointing too slowly causes the lineage and task sizes to grow which may have detrimental effects. Typically, a checkpoint interval of 5 - 10 times of sliding interval of a DStream is good setting to try. + +To enable checkpointing, the developer has to provide the HDFS path to which RDD will be saved. This is done by using + +{% highlight scala %} +ssc.checkpoint(hdfsPath) // assuming ssc is the StreamingContext +{% endhighlight %} + +The interval of checkpointing of a DStream can be set by using + +{% highlight scala %} +dstream.checkpoint(checkpointInterval) // checkpointInterval must be a multiple of slide duration of dstream +{% endhighlight %} + +For DStreams that must be checkpointed (that is, DStreams created by `updateStateByKey` and `reduceByKeyAndWindow` with inverse function), the checkpoint interval of the DStream is by default set to a multiple of the DStream's sliding interval such that its at least 10 seconds. # Performance Tuning @@ -273,17 +314,21 @@ Getting the best performance of a Spark Streaming application on a cluster requi There are a number of optimizations that can be done in Spark to minimize the processing time of each batch. These have been discussed in detail in [Tuning Guide](tuning.html). This section highlights some of the most important ones. ### Level of Parallelism -Cluster resources maybe underutilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is 8. You can pass the level of parallelism as an argument (see the [`spark.PairDStreamFunctions`](api/streaming/index.html#spark.PairDStreamFunctions) documentation), or set the system property `spark.default.parallelism` to change the default. +Cluster resources maybe under-utilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is 8. You can pass the level of parallelism as an argument (see the [`spark.PairDStreamFunctions`](api/streaming/index.html#spark.PairDStreamFunctions) documentation), or set the system property `spark.default.parallelism` to change the default. ### Data Serialization The overhead of data serialization can be significant, especially when sub-second batch sizes are to be achieved. There are two aspects to it. -* Serialization of RDD data in Spark: Please refer to the detailed discussion on data serialization in the [Tuning Guide](tuning.html). However, note that unlike Spark, by default RDDs are persisted as serialized byte arrays to minimize pauses related to GC. -* Serialization of input data: To ingest external data into Spark, data received as bytes (say, from the network) needs to deserialized from bytes and re-serialized into Spark's serialization format. Hence, the deserialization overhead of input data may be a bottleneck. + +* **Serialization of RDD data in Spark**: Please refer to the detailed discussion on data serialization in the [Tuning Guide](tuning.html). However, note that unlike Spark, by default RDDs are persisted as serialized byte arrays to minimize pauses related to GC. + +* **Serialization of input data**: To ingest external data into Spark, data received as bytes (say, from the network) needs to deserialized from bytes and re-serialized into Spark's serialization format. Hence, the deserialization overhead of input data may be a bottleneck. ### Task Launching Overheads If the number of tasks launched per second is high (say, 50 or more per second), then the overhead of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: -* Task Serialization: Using Kryo serialization for serializing tasks can reduced the task sizes, and therefore reduce the time taken to send them to the slaves. -* Execution mode: Running Spark in Standalone mode or coarse-grained Mesos mode leads to better task launch times than the fine-grained Mesos mode. Please refer to the [Running on Mesos guide](running-on-mesos.html) for more details. + +* **Task Serialization**: Using Kryo serialization for serializing tasks can reduced the task sizes, and therefore reduce the time taken to send them to the slaves. + +* **Execution mode**: Running Spark in Standalone mode or coarse-grained Mesos mode leads to better task launch times than the fine-grained Mesos mode. Please refer to the [Running on Mesos guide](running-on-mesos.html) for more details. These changes may reduce batch processing time by 100s of milliseconds, thus allowing sub-second batch size to be viable. ## Setting the Right Batch Size @@ -292,22 +337,182 @@ For a Spark Streaming application running on a cluster to be stable, the process A good approach to figure out the right batch size for your application is to test it with a conservative batch size (say, 5-10 seconds) and a low data rate. To verify whether the system is able to keep up with data rate, you can check the value of the end-to-end delay experienced by each processed batch (in the Spark master logs, find the line having the phrase "Total delay"). If the delay is maintained to be less than the batch size, then system is stable. Otherwise, if the delay is continuously increasing, it means that the system is unable to keep up and it therefore unstable. Once you have an idea of a stable configuration, you can try increasing the data rate and/or reducing the batch size. Note that momentary increase in the delay due to temporary data rate increases maybe fine as long as the delay reduces back to a low value (i.e., less than batch size). ## 24/7 Operation -By default, Spark does not forget any of the metadata (RDDs generated, stages processed, etc.). But for a Spark Streaming application to operate 24/7, it is necessary for Spark to do periodic cleanup of it metadata. This can be enabled by setting the Java system property `spark.cleaner.delay` to the number of minutes you want any metadata to persist. For example, setting `spark.cleaner.delay` to 10 would cause Spark periodically cleanup all metadata and persisted RDDs that are older than 10 minutes. Note, that this property needs to be set before the SparkContext is created. +By default, Spark does not forget any of the metadata (RDDs generated, stages processed, etc.). But for a Spark Streaming application to operate 24/7, it is necessary for Spark to do periodic cleanup of it metadata. This can be enabled by setting the Java system property `spark.cleaner.ttl` to the number of seconds you want any metadata to persist. For example, setting `spark.cleaner.ttl` to 600 would cause Spark periodically cleanup all metadata and persisted RDDs that are older than 10 minutes. Note, that this property needs to be set before the SparkContext is created. This value is closely tied with any window operation that is being used. Any window operation would require the input data to be persisted in memory for at least the duration of the window. Hence it is necessary to set the delay to at least the value of the largest window operation used in the Spark Streaming application. If this delay is set too low, the application will throw an exception saying so. ## Memory Tuning Tuning the memory usage and GC behavior of Spark applications have been discussed in great detail in the [Tuning Guide](tuning.html). It is recommended that you read that. In this section, we highlight a few customizations that are strongly recommended to minimize GC related pauses in Spark Streaming applications and achieving more consistent batch processing times. -* <b>Default persistence level of DStreams</b>: Unlike RDDs, the default persistence level of DStreams serializes the data in memory (that is, [StorageLevel.MEMORY_ONLY_SER](api/core/index.html#spark.storage.StorageLevel$) for DStream compared to [StorageLevel.MEMORY_ONLY](api/core/index.html#spark.storage.StorageLevel$) for RDDs). Even though keeping the data serialized incurs a higher serialization overheads, it significantly reduces GC pauses. +* **Default persistence level of DStreams**: Unlike RDDs, the default persistence level of DStreams serializes the data in memory (that is, [StorageLevel.MEMORY_ONLY_SER](api/core/index.html#spark.storage.StorageLevel$) for DStream compared to [StorageLevel.MEMORY_ONLY](api/core/index.html#spark.storage.StorageLevel$) for RDDs). Even though keeping the data serialized incurs a higher serialization overheads, it significantly reduces GC pauses. + +* **Concurrent garbage collector**: Using the concurrent mark-and-sweep GC further minimizes the variability of GC pauses. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more consistent batch processing times. + +# Fault-tolerance Properties +In this section, we are going to discuss the behavior of Spark Streaming application in the event of a node failure. To understand this, let us remember the basic fault-tolerance properties of Spark's RDDs. + + 1. An RDD is an immutable, and deterministically re-computable, distributed dataset. Each RDD remembers the lineage of deterministic operations that were used on a fault-tolerant input dataset to create it. + 1. If any partition of an RDD is lost due to a worker node failure, then that partition can be re-computed from the original fault-tolerant dataset using the lineage of operations. + +Since all data transformations in Spark Streaming are based on RDD operations, as long as the input dataset is present, all intermediate data can recomputed. Keeping these properties in mind, we are going to discuss the failure semantics in more detail. + +## Failure of a Worker Node + +There are two failure behaviors based on which input sources are used. + +1. _Using HDFS files as input source_ - Since the data is reliably stored on HDFS, all data can re-computed and therefore no data will be lost due to any failure. +1. _Using any input source that receives data through a network_ - For network-based data sources like Kafka and Flume, the received input data is replicated in memory between nodes of the cluster (default replication factor is 2). So if a worker node fails, then the system can recompute the lost from the the left over copy of the input data. However, if the worker node where a network receiver was running fails, then a tiny bit of data may be lost, that is, the data received by the system but not yet replicated to other node(s). The receiver will be started on a different node and it will continue to receive data. + +Since all data is modeled as RDDs with their lineage of deterministic operations, any recomputation always leads to the same result. As a result, all DStream transformations are guaranteed to have _exactly-once_ semantics. That is, the final transformed result will be same even if there were was a worker node failure. However, output operations (like `foreach`) have _at-least once_ semantics, that is, the transformed data may get written to an external entity more than once in the event of a worker failure. While this is acceptable for saving to HDFS using the `saveAs*Files` operations (as the file will simply get over-written by the same data), additional transactions-like mechanisms may be necessary to achieve exactly-once semantics for output operations. + +## Failure of the Driver Node +A system that is required to operate 24/7 needs to be able tolerate the failure of the driver node as well. Spark Streaming does this by saving the state of the DStream computation periodically to a HDFS file, that can be used to restart the streaming computation in the event of a failure of the driver node. This checkpointing is enabled by setting a HDFS directory for checkpointing using `ssc.checkpoint(<checkpoint directory>)` as described [earlier](#rdd-checkpointing-within-dstreams). To elaborate, the following state is periodically saved to a file. + +1. The DStream operator graph (input streams, output streams, etc.) +1. The configuration of each DStream (checkpoint interval, etc.) +1. The RDD checkpoint files of each DStream + +All this is periodically saved in the file `<checkpoint directory>/graph`. To recover, a new Streaming Context can be created with this directory by using + +{% highlight scala %} +val ssc = new StreamingContext(checkpointDirectory) +{% endhighlight %} + +On calling `ssc.start()` on this new context, the following steps are taken by the system + +1. Schedule the transformations and output operations for all the time steps between the time when the driver failed and when it was restarted. This is also done for those time steps that were scheduled but not processed due to the failure. This will make the system recompute all the intermediate data from the checkpointed RDD files, etc. +1. Restart the network receivers, if any, and continue receiving new data. -* <b>Concurrent garbage collector</b>: Using the concurrent mark-and-sweep GC further minimizes the variability of GC pauses. Even though concurrent GC is known to reduce the overall processing throughput of the system, its use is still recommended to achieve more consistent batch processing times. +In the current _alpha_ release, there are two different failure behaviors based on which input sources are used. -# Master Fault-tolerance (Alpha) -TODO +1. _Using HDFS files as input source_ - Since the data is reliably stored on HDFS, all data can re-computed and therefore no data will be lost due to any failure. +1. _Using any input source that receives data through a network_ - As aforesaid, the received input data is replicated in memory to multiple nodes. Since, all the data in the Spark worker's memory is lost when the Spark driver fails, the past input data will not be accessible and driver recovers. Hence, if stateful and window-based operations are used (like `updateStateByKey`, `window`, `countByValueAndWindow`, etc.), then the intermediate state will not be recovered completely. + +In future releases, this behaviour will be fixed for all input sources, that is, all data will be recovered irrespective of which input sources are used. Note that for non-stateful transformations like `map`, `count`, and `reduceByKey`, with _all_ input streams, the system, upon restarting, will continue to receive and process new data. + +To better understand the behavior of the system under driver failure with a HDFS source, lets consider what will happen with a file input stream Specifically, in the case of the file input stream, it will correctly identify new files that were created while the driver was down and process them in the same way as it would have if the driver had not failed. To explain further in the case of file input stream, we shall use an example. Lets say, files are being generated every second, and a Spark Streaming program reads every new file and output the number of lines in the file. This is what the sequence of outputs would be with and without a driver failure. + +<table class="table"> + <!-- Results table headers --> + <tr> + <th> Time </th> + <th> Number of lines in input file </th> + <th> Output without driver failure </th> + <th> Output with driver failure </th> + </tr> + <tr> + <td>1</td> + <td>10</td> + <td>10</td> + <td>10</td> + </tr> + <tr> + <td>2</td> + <td>20</td> + <td>20</td> + <td>20</td> + </tr> + <tr> + <td>3</td> + <td>30</td> + <td>30</td> + <td>30</td> + </tr> + <tr> + <td>4</td> + <td>40</td> + <td>40</td> + <td>[DRIVER FAILS]<br />no output</td> + </tr> + <tr> + <td>5</td> + <td>50</td> + <td>50</td> + <td>no output</td> + </tr> + <tr> + <td>6</td> + <td>60</td> + <td>60</td> + <td>no output</td> + </tr> + <tr> + <td>7</td> + <td>70</td> + <td>70</td> + <td>[DRIVER RECOVERS]<br />40, 50, 60, 70</td> + </tr> + <tr> + <td>8</td> + <td>80</td> + <td>80</td> + <td>80</td> + </tr> + <tr> + <td>9</td> + <td>90</td> + <td>90</td> + <td>90</td> + </tr> + <tr> + <td>10</td> + <td>100</td> + <td>100</td> + <td>100</td> + </tr> +</table> + +If the driver had crashed in the middle of the processing of time 3, then it will process time 3 and output 30 after recovery. + +# Java API + +Similar to [Spark's Java API](java-programming-guide.html), we also provide a Java API for Spark Streaming which allows all its features to be accessible from a Java program. This is defined in [spark.streaming.api.java] (api/streaming/index.html#spark.streaming.api.java.package) package and includes [JavaStreamingContext](api/streaming/index.html#spark.streaming.api.java.JavaStreamingContext) and [JavaDStream](api/streaming/index.html#spark.streaming.api.java.JavaDStream) classes that provide the same methods as their Scala counterparts, but take Java functions (that is, Function, and Function2) and return Java data and collection types. Some of the key points to note are: + +1. Functions for transformations must be implemented as subclasses of [Function](api/core/index.html#spark.api.java.function.Function) and [Function2](api/core/index.html#spark.api.java.function.Function2) +1. Unlike the Scala API, the Java API handles DStreams for key-value pairs using a separate [JavaPairDStream](api/streaming/index.html#spark.streaming.api.java.JavaPairDStream) class(similar to [JavaRDD and JavaPairRDD](java-programming-guide.html#rdd-classes). DStream functions like `map` and `filter` are implemented separately by JavaDStreams and JavaPairDStream to return DStreams of appropriate types. + +Spark's [Java Programming Guide](java-programming-guide.html) gives more ideas about using the Java API. To extends the ideas presented for the RDDs to DStreams, we present parts of the Java version of the same NetworkWordCount example presented above. The full source code is given at `<spark repo>/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java` + +The streaming context and the socket stream from input source is started by using a `JavaStreamingContext`, that has the same parameters and provides the same input streams as its Scala counterpart. + +{% highlight java %} +JavaStreamingContext ssc = new JavaStreamingContext(mesosUrl, "NetworkWordCount", Seconds(1)); +JavaDStream<String> lines = ssc.socketTextStream(ip, port); +{% endhighlight %} + + +Then the `lines` are split into words by using the `flatMap` function and [FlatMapFunction](api/core/index.html#spark.api.java.function.FlatMapFunction). + +{% highlight java %} +JavaDStream<String> words = lines.flatMap( + new FlatMapFunction<String, String>() { + @Override + public Iterable<String> call(String x) { + return Lists.newArrayList(x.split(" ")); + } + }); +{% endhighlight %} + +The `words` is then mapped to a [JavaPairDStream](api/streaming/index.html#spark.streaming.api.java.JavaPairDStream) of `(word, 1)` pairs using `map` and [PairFunction](api/core/index.html#spark.api.java.function.PairFunction). This is reduced by using `reduceByKey` and [Function2](api/core/index.html#spark.api.java.function.Function2). + +{% highlight java %} +JavaPairDStream<String, Integer> wordCounts = words.map( + new PairFunction<String, String, Integer>() { + @Override + public Tuple2<String, Integer> call(String s) throws Exception { + return new Tuple2<String, Integer>(s, 1); + } + }).reduceByKey( + new Function2<Integer, Integer, Integer>() { + @Override + public Integer call(Integer i1, Integer i2) throws Exception { + return i1 + i2; + } + }); +{% endhighlight %} -* Checkpointing of DStream graph -* Recovery from master faults -* Current state and future directions \ No newline at end of file +# Where to Go from Here +* Documentation - [Scala](api/streaming/index.html#spark.streaming.package) and [Java](api/streaming/index.html#spark.streaming.api.java.package) +* More examples - [Scala](https://github.com/mesos/spark/tree/master/examples/src/main/scala/spark/streaming/examples) and [Java](https://github.com/mesos/spark/tree/master/examples/src/main/java/spark/streaming/examples) diff --git a/docs/tuning.md b/docs/tuning.md index 9aaa53cd65205901e62f39924b696abbecff4d8c..843380b9a28829898877dd092018e1b31ddabca5 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -213,10 +213,10 @@ but at a high level, managing how frequently full GC takes place can help in red Clusters will not be fully utilized unless you set the level of parallelism for each operation high enough. Spark automatically sets the number of "map" tasks to run on each file according to its size -(though you can control it through optional parameters to `SparkContext.textFile`, etc), but for -distributed "reduce" operations, such as `groupByKey` and `reduceByKey`, it uses a default value of 8. -You can pass the level of parallelism as a second argument (see the -[`spark.PairRDDFunctions`](api/core/index.html#spark.PairRDDFunctions) documentation), +(though you can control it through optional parameters to `SparkContext.textFile`, etc), and for +distributed "reduce" operations, such as `groupByKey` and `reduceByKey`, it uses the largest +parent RDD's number of partitions. You can pass the level of parallelism as a second argument +(see the [`spark.PairRDDFunctions`](api/core/index.html#spark.PairRDDFunctions) documentation), or set the system property `spark.default.parallelism` to change the default. In general, we recommend 2-3 tasks per CPU core in your cluster. @@ -233,7 +233,7 @@ number of cores in your clusters. ## Broadcasting Large Variables -Using the [broadcast functionality](scala-programming-guide#broadcast-variables) +Using the [broadcast functionality](scala-programming-guide.html#broadcast-variables) available in `SparkContext` can greatly reduce the size of each serialized task, and the cost of launching a job over a cluster. If your tasks use any large object from the driver program inside of them (e.g. a static lookup table), consider turning it into a broadcast variable. diff --git a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh new file mode 100644 index 0000000000000000000000000000000000000000..166a884c889e4b154a4fe1b29c8aa61b4c382bed --- /dev/null +++ b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# These variables are automatically filled in by the mesos-ec2 script. +export MESOS_MASTERS="{{master_list}}" +export MESOS_SLAVES="{{slave_list}}" +export MESOS_ZOO_LIST="{{zoo_list}}" +export MESOS_HDFS_DATA_DIRS="{{hdfs_data_dirs}}" +export MESOS_MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}" +export MESOS_SPARK_LOCAL_DIRS="{{spark_local_dirs}}" +export MODULES="{{modules}}" +export SWAP_MB="{{swap}}" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index a5384d3bda1c61d38cec2f4a6d2e70ff1fd84938..66b1faf2cd8314d54ca4bd0ba9f2c65f0663ba55 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -82,12 +82,21 @@ def parse_args(): parser.add_option("--spot-price", metavar="PRICE", type="float", help="If specified, launch slaves as spot instances with the given " + "maximum price (in dollars)") - parser.add_option("-c", "--cluster-type", default="mesos", - help="'mesos' for a mesos cluster, 'standalone' for a standalone spark cluster (default: mesos)") + parser.add_option("--cluster-type", type="choice", metavar="TYPE", + choices=["mesos", "standalone"], default="mesos", + help="'mesos' for a Mesos cluster, 'standalone' for a standalone " + + "Spark cluster (default: mesos)") + parser.add_option("--ganglia", action="store_true", default=True, + help="Setup Ganglia monitoring on cluster (default: on). NOTE: " + + "the Ganglia page will be publicly accessible") + parser.add_option("--no-ganglia", action="store_false", dest="ganglia", + help="Disable Ganglia monitoring for the cluster") + parser.add_option("--new-scripts", action="store_true", default=False, + help="Use new spark-ec2 scripts, for Spark >= 0.7 AMIs") parser.add_option("-u", "--user", default="root", - help="The ssh user you want to connect as (default: root)") + help="The SSH user you want to connect as (default: root)") parser.add_option("--delete-groups", action="store_true", default=False, - help="When destroying a cluster, also destroy the security groups that were created") + help="When destroying a cluster, delete the security groups that were created") (opts, args) = parser.parse_args() if len(args) != 2: @@ -164,22 +173,23 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize(src_group=zoo_group) master_group.authorize('tcp', 22, 22, '0.0.0.0/0') master_group.authorize('tcp', 8080, 8081, '0.0.0.0/0') + master_group.authorize('tcp', 50030, 50030, '0.0.0.0/0') + master_group.authorize('tcp', 50070, 50070, '0.0.0.0/0') + master_group.authorize('tcp', 60070, 60070, '0.0.0.0/0') if opts.cluster_type == "mesos": - master_group.authorize('tcp', 50030, 50030, '0.0.0.0/0') - master_group.authorize('tcp', 50070, 50070, '0.0.0.0/0') - master_group.authorize('tcp', 60070, 60070, '0.0.0.0/0') master_group.authorize('tcp', 38090, 38090, '0.0.0.0/0') + if opts.ganglia: + master_group.authorize('tcp', 5080, 5080, '0.0.0.0/0') if slave_group.rules == []: # Group was just now created slave_group.authorize(src_group=master_group) slave_group.authorize(src_group=slave_group) slave_group.authorize(src_group=zoo_group) slave_group.authorize('tcp', 22, 22, '0.0.0.0/0') slave_group.authorize('tcp', 8080, 8081, '0.0.0.0/0') - if opts.cluster_type == "mesos": - slave_group.authorize('tcp', 50060, 50060, '0.0.0.0/0') - slave_group.authorize('tcp', 50075, 50075, '0.0.0.0/0') - slave_group.authorize('tcp', 60060, 60060, '0.0.0.0/0') - slave_group.authorize('tcp', 60075, 60075, '0.0.0.0/0') + slave_group.authorize('tcp', 50060, 50060, '0.0.0.0/0') + slave_group.authorize('tcp', 50075, 50075, '0.0.0.0/0') + slave_group.authorize('tcp', 60060, 60060, '0.0.0.0/0') + slave_group.authorize('tcp', 60075, 60075, '0.0.0.0/0') if zoo_group.rules == []: # Group was just now created zoo_group.authorize(src_group=master_group) zoo_group.authorize(src_group=slave_group) @@ -358,19 +368,38 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): # Deploy configuration files and run setup scripts on a newly launched # or started EC2 cluster. def setup_cluster(conn, master_nodes, slave_nodes, zoo_nodes, opts, deploy_ssh_key): - print "Deploying files to master..." - deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, zoo_nodes) master = master_nodes[0].public_dns_name if deploy_ssh_key: print "Copying SSH key %s to master..." % opts.identity_file ssh(master, opts, 'mkdir -p ~/.ssh') scp(master, opts, opts.identity_file, '~/.ssh/id_rsa') ssh(master, opts, 'chmod 600 ~/.ssh/id_rsa') - print "Running setup on master..." + if opts.cluster_type == "mesos": - setup_mesos_cluster(master, opts) + modules = ['ephemeral-hdfs', 'persistent-hdfs', 'mesos'] elif opts.cluster_type == "standalone": - setup_standalone_cluster(master, slave_nodes, opts) + modules = ['ephemeral-hdfs', 'persistent-hdfs', 'spark-standalone'] + + if opts.ganglia: + modules.append('ganglia') + + if opts.new_scripts: + # NOTE: We should clone the repository before running deploy_files to + # prevent ec2-variables.sh from being overwritten + ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git") + + print "Deploying files to master..." + deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, + zoo_nodes, modules) + + print "Running setup on master..." + if not opts.new_scripts: + if opts.cluster_type == "mesos": + setup_mesos_cluster(master, opts) + elif opts.cluster_type == "standalone": + setup_standalone_cluster(master, slave_nodes, opts) + else: + setup_spark_cluster(master, opts) print "Done!" def setup_mesos_cluster(master, opts): @@ -383,6 +412,17 @@ def setup_standalone_cluster(master, slave_nodes, opts): ssh(master, opts, "echo \"%s\" > spark/conf/slaves" % (slave_ips)) ssh(master, opts, "/root/spark/bin/start-all.sh") +def setup_spark_cluster(master, opts): + ssh(master, opts, "chmod u+x spark-ec2/setup.sh") + ssh(master, opts, "spark-ec2/setup.sh") + if opts.cluster_type == "mesos": + print "Mesos cluster started at http://%s:8080" % master + elif opts.cluster_type == "standalone": + print "Spark standalone cluster started at http://%s:8080" % master + + if opts.ganglia: + print "Ganglia started at http://%s:5080/ganglia" % master + # Wait for a whole cluster (masters, slaves and ZooKeeper) to start up def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes, zoo_nodes): @@ -427,7 +467,8 @@ def get_num_disks(instance_type): # cluster (e.g. lists of masters and slaves). Files are only deployed to # the first master instance in the cluster, and we expect the setup # script to be run on that instance to copy them to other nodes. -def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, zoo_nodes): +def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, zoo_nodes, + modules): active_master = master_nodes[0].public_dns_name num_disks = get_num_disks(opts.instance_type) @@ -459,7 +500,9 @@ def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, zoo_nodes): "cluster_url": cluster_url, "hdfs_data_dirs": hdfs_data_dirs, "mapred_local_dirs": mapred_local_dirs, - "spark_local_dirs": spark_local_dirs + "spark_local_dirs": spark_local_dirs, + "swap": str(opts.swap), + "modules": '\n'.join(modules) } # Create a temp directory in which we will place all the files to be diff --git a/examples/pom.xml b/examples/pom.xml index f43af670c613fb6fa3dcaa09717b3e16e83af6da..7d975875fac3afe8e733c3549bcc4437d0f236f6 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -20,11 +20,10 @@ <artifactId>jetty-server</artifactId> </dependency> <dependency> - <groupId>org.twitter4j</groupId> - <artifactId>twitter4j-stream</artifactId> - <version>3.0.3</version> + <groupId>com.twitter</groupId> + <artifactId>algebird-core_2.9.2</artifactId> + <version>0.1.8</version> </dependency> - <dependency> <groupId>org.scalatest</groupId> <artifactId>scalatest_${scala.version}</artifactId> diff --git a/examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java b/examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java similarity index 100% rename from examples/src/main/scala/spark/streaming/examples/JavaFlumeEventCount.java rename to examples/src/main/java/spark/streaming/examples/JavaFlumeEventCount.java diff --git a/examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java b/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java similarity index 95% rename from examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java rename to examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java index 4299febfd6ca3b57569ea1b22e772015975ec8de..0e9eadd01b7111b7c98258a2d3e252bbb09c7618 100644 --- a/examples/src/main/scala/spark/streaming/examples/JavaNetworkWordCount.java +++ b/examples/src/main/java/spark/streaming/examples/JavaNetworkWordCount.java @@ -23,7 +23,7 @@ import spark.streaming.api.java.JavaStreamingContext; */ public class JavaNetworkWordCount { public static void main(String[] args) { - if (args.length < 2) { + if (args.length < 3) { System.err.println("Usage: NetworkWordCount <master> <hostname> <port>\n" + "In local mode, <master> should be 'local[n]' with n > 1"); System.exit(1); @@ -35,7 +35,7 @@ public class JavaNetworkWordCount { // Create a NetworkInputDStream on target ip:port and count the // words in input stream of \n delimited test (eg. generated by 'nc') - JavaDStream<String> lines = ssc.networkTextStream(args[1], Integer.parseInt(args[2])); + JavaDStream<String> lines = ssc.socketTextStream(args[1], Integer.parseInt(args[2])); JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() { @Override public Iterable<String> call(String x) { diff --git a/examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java b/examples/src/main/java/spark/streaming/examples/JavaQueueStream.java similarity index 100% rename from examples/src/main/scala/spark/streaming/examples/JavaQueueStream.java rename to examples/src/main/java/spark/streaming/examples/JavaQueueStream.java diff --git a/examples/src/main/scala/spark/examples/LogQuery.scala b/examples/src/main/scala/spark/examples/LogQuery.scala new file mode 100644 index 0000000000000000000000000000000000000000..5330b8da9444f46d61cd2af3403f53c26da72e4a --- /dev/null +++ b/examples/src/main/scala/spark/examples/LogQuery.scala @@ -0,0 +1,66 @@ +package spark.examples + +import spark.SparkContext +import spark.SparkContext._ +/** + * Executes a roll up-style query against Apache logs. + */ +object LogQuery { + val exampleApacheLogs = List( + """10.10.10.10 - "FRED" [18/Jan/2013:17:56:07 +1100] "GET http://images.com/2013/Generic.jpg + | HTTP/1.1" 304 315 "http://referall.com/" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; + | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR + | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR + | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.350 "-" - "" 265 923 934 "" + | 62.24.11.25 images.com 1358492167 - Whatup""".stripMargin.replace("\n", ""), + """10.10.10.10 - "FRED" [18/Jan/2013:18:02:37 +1100] "GET http://images.com/2013/Generic.jpg + | HTTP/1.1" 304 306 "http:/referall.com" "Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1; + | GTB7.4; .NET CLR 2.0.50727; .NET CLR 3.0.04506.30; .NET CLR 3.0.04506.648; .NET CLR + | 3.5.21022; .NET CLR 3.0.4506.2152; .NET CLR 1.0.3705; .NET CLR 1.1.4322; .NET CLR + | 3.5.30729; Release=ARP)" "UD-1" - "image/jpeg" "whatever" 0.352 "-" - "" 256 977 988 "" + | 0 73.23.2.15 images.com 1358492557 - Whatup""".stripMargin.replace("\n", "") + ) + + def main(args: Array[String]) { + if (args.length == 0) { + System.err.println("Usage: LogQuery <master> [logFile]") + System.exit(1) + } + val sc = new SparkContext(args(0), "Log Query") + + val dataSet = + if (args.length == 2) sc.textFile(args(1)) + else sc.parallelize(exampleApacheLogs) + + val apacheLogRegex = + """^([\d.]+) (\S+) (\S+) \[([\w\d:/]+\s[+\-]\d{4})\] "(.+?)" (\d{3}) ([\d\-]+) "([^"]+)" "([^"]+)".*""".r + + /** Tracks the total query count and number of aggregate bytes for a particular group. */ + class Stats(val count: Int, val numBytes: Int) extends Serializable { + def merge(other: Stats) = new Stats(count + other.count, numBytes + other.numBytes) + override def toString = "bytes=%s\tn=%s".format(numBytes, count) + } + + def extractKey(line: String): (String, String, String) = { + apacheLogRegex.findFirstIn(line) match { + case Some(apacheLogRegex(ip, _, user, dateTime, query, status, bytes, referer, ua)) => + if (user != "\"-\"") (ip, user, query) + else (null, null, null) + case _ => (null, null, null) + } + } + + def extractStats(line: String): Stats = { + apacheLogRegex.findFirstIn(line) match { + case Some(apacheLogRegex(ip, _, user, dateTime, query, status, bytes, referer, ua)) => + new Stats(1, bytes.toInt) + case _ => new Stats(1, 0) + } + } + + dataSet.map(line => (extractKey(line), extractStats(line))) + .reduceByKey((a, b) => a.merge(b)) + .collect().foreach{ + case (user, query) => println("%s\t%s".format(user, query))} + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala new file mode 100644 index 0000000000000000000000000000000000000000..76293fbb96bf7be6b2a61af1b597a3e4b2d6ddca --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/ActorWordCount.scala @@ -0,0 +1,157 @@ +package spark.streaming.examples + +import scala.collection.mutable.LinkedList +import scala.util.Random + +import akka.actor.Actor +import akka.actor.ActorRef +import akka.actor.Props +import akka.actor.actorRef2Scala + +import spark.streaming.Seconds +import spark.streaming.StreamingContext +import spark.streaming.StreamingContext.toPairDStreamFunctions +import spark.streaming.receivers.Receiver +import spark.util.AkkaUtils + +case class SubscribeReceiver(receiverActor: ActorRef) +case class UnsubscribeReceiver(receiverActor: ActorRef) + +/** + * Sends the random content to every receiver subscribed with 1/2 + * second delay. + */ +class FeederActor extends Actor { + + val rand = new Random() + var receivers: LinkedList[ActorRef] = new LinkedList[ActorRef]() + + val strings: Array[String] = Array("words ", "may ", "count ") + + def makeMessage(): String = { + val x = rand.nextInt(3) + strings(x) + strings(2 - x) + } + + /* + * A thread to generate random messages + */ + new Thread() { + override def run() { + while (true) { + Thread.sleep(500) + receivers.foreach(_ ! makeMessage) + } + } + }.start() + + def receive: Receive = { + + case SubscribeReceiver(receiverActor: ActorRef) => + println("received subscribe from %s".format(receiverActor.toString)) + receivers = LinkedList(receiverActor) ++ receivers + + case UnsubscribeReceiver(receiverActor: ActorRef) => + println("received unsubscribe from %s".format(receiverActor.toString)) + receivers = receivers.dropWhile(x => x eq receiverActor) + + } +} + +/** + * A sample actor as receiver, is also simplest. This receiver actor + * goes and subscribe to a typical publisher/feeder actor and receives + * data. + * + * @see [[spark.streaming.examples.FeederActor]] + */ +class SampleActorReceiver[T: ClassManifest](urlOfPublisher: String) +extends Actor with Receiver { + + lazy private val remotePublisher = context.actorFor(urlOfPublisher) + + override def preStart = remotePublisher ! SubscribeReceiver(context.self) + + def receive = { + case msg ⇒ context.parent ! pushBlock(msg.asInstanceOf[T]) + } + + override def postStop() = remotePublisher ! UnsubscribeReceiver(context.self) + +} + +/** + * A sample feeder actor + * + * Usage: FeederActor <hostname> <port> + * <hostname> and <port> describe the AkkaSystem that Spark Sample feeder would start on. + */ +object FeederActor { + + def main(args: Array[String]) { + if(args.length < 2){ + System.err.println( + "Usage: FeederActor <hostname> <port>\n" + ) + System.exit(1) + } + val Seq(host, port) = args.toSeq + + + val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt)._1 + val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor") + + println("Feeder started as:" + feeder) + + actorSystem.awaitTermination(); + } +} + +/** + * A sample word count program demonstrating the use of plugging in + * Actor as Receiver + * Usage: ActorWordCount <master> <hostname> <port> + * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1. + * <hostname> and <port> describe the AkkaSystem that Spark Sample feeder is running on. + * + * To run this example locally, you may run Feeder Actor as + * `$ ./run spark.streaming.examples.FeederActor 127.0.1.1 9999` + * and then run the example + * `$ ./run spark.streaming.examples.ActorWordCount local[2] 127.0.1.1 9999` + */ +object ActorWordCount { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println( + "Usage: ActorWordCount <master> <hostname> <port>" + + "In local mode, <master> should be 'local[n]' with n > 1") + System.exit(1) + } + + val Seq(master, host, port) = args.toSeq + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "ActorWordCount", Seconds(2)) + + /* + * Following is the use of actorStream to plug in custom actor as receiver + * + * An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e type of data received and InputDstream + * should be same. + * + * For example: Both actorStream and SampleActorReceiver are parameterized + * to same type to ensure type safety. + */ + + val lines = ssc.actorStream[String]( + Props(new SampleActorReceiver[String]("akka://test@%s:%s/user/FeederActor".format( + host, port.toInt))), "SampleReceiver") + + //compute wordcount + lines.flatMap(_.split("\\s+")).map(x => (x, 1)).reduceByKey(_ + _).print() + + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala index fe55db6e2c6314739cf3c3e2fd4f34511aff7b6f..9b135a5c54cf376bccfb9f334ccbd070ddb34924 100644 --- a/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala +++ b/examples/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -10,22 +10,34 @@ import spark.streaming.StreamingContext._ import spark.storage.StorageLevel import spark.streaming.util.RawTextHelper._ +/** + * Consumes messages from one or more topics in Kafka and does wordcount. + * Usage: KafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads> + * <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1. + * <zkQuorum> is a list of one or more zookeeper servers that make quorum + * <group> is the name of kafka consumer group + * <topics> is a list of one or more kafka topics to consume from + * <numThreads> is the number of threads the kafka consumer should use + * + * Example: + * `./run spark.streaming.examples.KafkaWordCount local[2] zoo01,zoo02,zoo03 my-consumer-group topic1,topic2 1` + */ object KafkaWordCount { def main(args: Array[String]) { - if (args.length < 6) { - System.err.println("Usage: KafkaWordCount <master> <hostname> <port> <group> <topics> <numThreads>") + if (args.length < 5) { + System.err.println("Usage: KafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>") System.exit(1) } - val Array(master, hostname, port, group, topics, numThreads) = args + val Array(master, zkQuorum, group, topics, numThreads) = args val sc = new SparkContext(master, "KafkaWordCount") val ssc = new StreamingContext(sc, Seconds(2)) ssc.checkpoint("checkpoint") val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap - val lines = ssc.kafkaStream[String](hostname, port.toInt, group, topicpMap) + val lines = ssc.kafkaStream[String](zkQuorum, group, topicpMap) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) wordCounts.print() @@ -38,16 +50,16 @@ object KafkaWordCount { object KafkaWordCountProducer { def main(args: Array[String]) { - if (args.length < 3) { - System.err.println("Usage: KafkaWordCountProducer <hostname> <port> <topic> <messagesPerSec> <wordsPerMessage>") + if (args.length < 2) { + System.err.println("Usage: KafkaWordCountProducer <zkQuorum> <topic> <messagesPerSec> <wordsPerMessage>") System.exit(1) } - val Array(hostname, port, topic, messagesPerSec, wordsPerMessage) = args + val Array(zkQuorum, topic, messagesPerSec, wordsPerMessage) = args // Zookeper connection properties val props = new Properties() - props.put("zk.connect", hostname + ":" + port) + props.put("zk.connect", zkQuorum) props.put("serializer.class", "kafka.serializer.StringEncoder") val config = new ProducerConfig(props) diff --git a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala index 32f7d57bea485a6d87be8537635dc2793d1666e6..5ac6d19b349da7dd9e4be4d019db974e63107038 100644 --- a/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala +++ b/examples/src/main/scala/spark/streaming/examples/NetworkWordCount.scala @@ -16,7 +16,7 @@ import spark.streaming.StreamingContext._ */ object NetworkWordCount { def main(args: Array[String]) { - if (args.length < 2) { + if (args.length < 3) { System.err.println("Usage: NetworkWordCount <master> <hostname> <port>\n" + "In local mode, <master> should be 'local[n]' with n > 1") System.exit(1) @@ -27,7 +27,7 @@ object NetworkWordCount { // Create a NetworkInputDStream on target ip:port and count the // words in input stream of \n delimited test (eg. generated by 'nc') - val lines = ssc.networkTextStream(args(1), args(2).toInt) + val lines = ssc.socketTextStream(args(1), args(2).toInt) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.print() diff --git a/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala index 2eec777c54e8d21da0a686c3768202ba1c0a1b93..66e709b7a333836fd1950ac3da50d8b87194fea1 100644 --- a/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala +++ b/examples/src/main/scala/spark/streaming/examples/RawNetworkGrep.scala @@ -37,7 +37,7 @@ object RawNetworkGrep { RawTextHelper.warmUp(ssc.sc) val rawStreams = (1 to numStreams).map(_ => - ssc.rawNetworkStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray + ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray val union = ssc.union(rawStreams) union.filter(_.contains("the")).count().foreach(r => println("Grep count: " + r.collect().mkString)) diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala new file mode 100644 index 0000000000000000000000000000000000000000..39a1a702eeae925f278c0b692e2c929bf5e7b74e --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdCMS.scala @@ -0,0 +1,93 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.storage.StorageLevel +import com.twitter.algebird._ +import spark.streaming.StreamingContext._ +import spark.SparkContext._ + +/** + * Illustrates the use of the Count-Min Sketch, from Twitter's Algebird library, to compute + * windowed and global Top-K estimates of user IDs occurring in a Twitter stream. + * <br> + * <strong>Note</strong> that since Algebird's implementation currently only supports Long inputs, + * the example operates on Long IDs. Once the implementation supports other inputs (such as String), + * the same approach could be used for computing popular topics for example. + * <p> + * <p> + * <a href="http://highlyscalable.wordpress.com/2012/05/01/probabilistic-structures-web-analytics-data-mining/"> + * This blog post</a> has a good overview of the Count-Min Sketch (CMS). The CMS is a datastructure + * for approximate frequency estimation in data streams (e.g. Top-K elements, frequency of any given element, etc), + * that uses space sub-linear in the number of elements in the stream. Once elements are added to the CMS, the + * estimated count of an element can be computed, as well as "heavy-hitters" that occur more than a threshold + * percentage of the overall total count. + * <p><p> + * Algebird's implementation is a monoid, so we can succinctly merge two CMS instances in the reduce operation. + */ +object TwitterAlgebirdCMS { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: TwitterAlgebirdCMS <master> <twitter_username> <twitter_password>" + + " [filter1] [filter2] ... [filter n]") + System.exit(1) + } + + // CMS parameters + val DELTA = 1E-3 + val EPS = 0.01 + val SEED = 1 + val PERC = 0.001 + // K highest frequency elements to take + val TOPK = 10 + + val Array(master, username, password) = args.slice(0, 3) + val filters = args.slice(3, args.length) + + val ssc = new StreamingContext(master, "TwitterAlgebirdCMS", Seconds(10)) + val stream = ssc.twitterStream(username, password, filters, StorageLevel.MEMORY_ONLY_SER) + + val users = stream.map(status => status.getUser.getId) + + val cms = new CountMinSketchMonoid(DELTA, EPS, SEED, PERC) + var globalCMS = cms.zero + val mm = new MapMonoid[Long, Int]() + var globalExact = Map[Long, Int]() + + val approxTopUsers = users.mapPartitions(ids => { + ids.map(id => cms.create(id)) + }).reduce(_ ++ _) + + val exactTopUsers = users.map(id => (id, 1)) + .reduceByKey((a, b) => a + b) + + approxTopUsers.foreach(rdd => { + if (rdd.count() != 0) { + val partial = rdd.first() + val partialTopK = partial.heavyHitters.map(id => + (id, partial.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) + globalCMS ++= partial + val globalTopK = globalCMS.heavyHitters.map(id => + (id, globalCMS.frequency(id).estimate)).toSeq.sortBy(_._2).reverse.slice(0, TOPK) + println("Approx heavy hitters at %2.2f%% threshold this batch: %s".format(PERC, + partialTopK.mkString("[", ",", "]"))) + println("Approx heavy hitters at %2.2f%% threshold overall: %s".format(PERC, + globalTopK.mkString("[", ",", "]"))) + } + }) + + exactTopUsers.foreach(rdd => { + if (rdd.count() != 0) { + val partialMap = rdd.collect().toMap + val partialTopK = rdd.map( + {case (id, count) => (count, id)}) + .sortByKey(ascending = false).take(TOPK) + globalExact = mm.plus(globalExact.toMap, partialMap) + val globalTopK = globalExact.toSeq.sortBy(_._2).reverse.slice(0, TOPK) + println("Exact heavy hitters this batch: %s".format(partialTopK.mkString("[", ",", "]"))) + println("Exact heavy hitters overall: %s".format(globalTopK.mkString("[", ",", "]"))) + } + }) + + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala new file mode 100644 index 0000000000000000000000000000000000000000..914fba4ca22c54a2678ce2c14b2db1d1eb976265 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/TwitterAlgebirdHLL.scala @@ -0,0 +1,71 @@ +package spark.streaming.examples + +import spark.streaming.{Seconds, StreamingContext} +import spark.storage.StorageLevel +import com.twitter.algebird.HyperLogLog._ +import com.twitter.algebird.HyperLogLogMonoid +import spark.streaming.dstream.TwitterInputDStream + +/** + * Illustrates the use of the HyperLogLog algorithm, from Twitter's Algebird library, to compute + * a windowed and global estimate of the unique user IDs occurring in a Twitter stream. + * <p> + * <p> + * This <a href="http://highlyscalable.wordpress.com/2012/05/01/probabilistic-structures-web-analytics-data-mining/"> + * blog post</a> and this + * <a href="http://highscalability.com/blog/2012/4/5/big-data-counting-how-to-count-a-billion-distinct-objects-us.html">blog post</a> + * have good overviews of HyperLogLog (HLL). HLL is a memory-efficient datastructure for estimating + * the cardinality of a data stream, i.e. the number of unique elements. + * <p><p> + * Algebird's implementation is a monoid, so we can succinctly merge two HLL instances in the reduce operation. + */ +object TwitterAlgebirdHLL { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: TwitterAlgebirdHLL <master> <twitter_username> <twitter_password>" + + " [filter1] [filter2] ... [filter n]") + System.exit(1) + } + + /** Bit size parameter for HyperLogLog, trades off accuracy vs size */ + val BIT_SIZE = 12 + val Array(master, username, password) = args.slice(0, 3) + val filters = args.slice(3, args.length) + + val ssc = new StreamingContext(master, "TwitterAlgebirdHLL", Seconds(5)) + val stream = ssc.twitterStream(username, password, filters, StorageLevel.MEMORY_ONLY_SER) + + val users = stream.map(status => status.getUser.getId) + + val hll = new HyperLogLogMonoid(BIT_SIZE) + var globalHll = hll.zero + var userSet: Set[Long] = Set() + + val approxUsers = users.mapPartitions(ids => { + ids.map(id => hll(id)) + }).reduce(_ + _) + + val exactUsers = users.map(id => Set(id)).reduce(_ ++ _) + + approxUsers.foreach(rdd => { + if (rdd.count() != 0) { + val partial = rdd.first() + globalHll += partial + println("Approx distinct users this batch: %d".format(partial.estimatedSize.toInt)) + println("Approx distinct users overall: %d".format(globalHll.estimatedSize.toInt)) + } + }) + + exactUsers.foreach(rdd => { + if (rdd.count() != 0) { + val partial = rdd.first() + userSet ++= partial + println("Exact distinct users this batch: %d".format(partial.size)) + println("Exact distinct users overall: %d".format(userSet.size)) + println("Error rate: %2.5f%%".format(((globalHll.estimatedSize / userSet.size.toDouble) - 1) * 100)) + } + }) + + ssc.start() + } +} diff --git a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala similarity index 55% rename from examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala rename to examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala index 377bc0c98ec1f1185a97e8e559e7151e5aef7e85..fdb3a4c73c6cc6cfeca3e309c341e6e0df75f05c 100644 --- a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterBasic.scala +++ b/examples/src/main/scala/spark/streaming/examples/TwitterPopularTags.scala @@ -1,19 +1,19 @@ -package spark.streaming.examples.twitter +package spark.streaming.examples -import spark.streaming.StreamingContext._ import spark.streaming.{Seconds, StreamingContext} +import StreamingContext._ import spark.SparkContext._ -import spark.storage.StorageLevel /** * Calculates popular hashtags (topics) over sliding 10 and 60 second windows from a Twitter * stream. The stream is instantiated with credentials and optionally filters supplied by the * command line arguments. + * */ -object TwitterBasic { +object TwitterPopularTags { def main(args: Array[String]) { if (args.length < 3) { - System.err.println("Usage: TwitterBasic <master> <twitter_username> <twitter_password>" + + System.err.println("Usage: TwitterPopularTags <master> <twitter_username> <twitter_password>" + " [filter1] [filter2] ... [filter n]") System.exit(1) } @@ -21,10 +21,8 @@ object TwitterBasic { val Array(master, username, password) = args.slice(0, 3) val filters = args.slice(3, args.length) - val ssc = new StreamingContext(master, "TwitterBasic", Seconds(2)) - val stream = new TwitterInputDStream(ssc, username, password, filters, - StorageLevel.MEMORY_ONLY_SER) - ssc.registerInputStream(stream) + val ssc = new StreamingContext(master, "TwitterPopularTags", Seconds(2)) + val stream = ssc.twitterStream(username, password, filters) val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) @@ -39,22 +37,17 @@ object TwitterBasic { // Print popular hashtags topCounts60.foreach(rdd => { - if (rdd.count() != 0) { - val topList = rdd.take(5) - println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count())) - topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} - } + val topList = rdd.take(5) + println("\nPopular topics in last 60 seconds (%s total):".format(rdd.count())) + topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} }) topCounts10.foreach(rdd => { - if (rdd.count() != 0) { - val topList = rdd.take(5) - println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count())) - topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} - } + val topList = rdd.take(5) + println("\nPopular topics in last 10 seconds (%s total):".format(rdd.count())) + topList.foreach{case (count, tag) => println("%s (%s tweets)".format(tag, count))} }) ssc.start() } - } diff --git a/examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala b/examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala new file mode 100644 index 0000000000000000000000000000000000000000..5ed9b7cb768750e5e140802bdeca782fae134847 --- /dev/null +++ b/examples/src/main/scala/spark/streaming/examples/ZeroMQWordCount.scala @@ -0,0 +1,73 @@ +package spark.streaming.examples + +import akka.actor.ActorSystem +import akka.actor.actorRef2Scala +import akka.zeromq._ +import spark.streaming.{ Seconds, StreamingContext } +import spark.streaming.StreamingContext._ +import akka.zeromq.Subscribe + +/** + * A simple publisher for demonstration purposes, repeatedly publishes random Messages + * every one second. + */ +object SimpleZeroMQPublisher { + + def main(args: Array[String]) = { + if (args.length < 2) { + System.err.println("Usage: SimpleZeroMQPublisher <zeroMQUrl> <topic> ") + System.exit(1) + } + + val Seq(url, topic) = args.toSeq + val acs: ActorSystem = ActorSystem() + + val pubSocket = ZeroMQExtension(acs).newSocket(SocketType.Pub, Bind(url)) + val messages: Array[String] = Array("words ", "may ", "count ") + while (true) { + Thread.sleep(1000) + pubSocket ! ZMQMessage(Frame(topic) :: messages.map(x => Frame(x.getBytes)).toList) + } + acs.awaitTermination() + } +} + +/** + * A sample wordcount with ZeroMQStream stream + * + * To work with zeroMQ, some native libraries have to be installed. + * Install zeroMQ (release 2.1) core libraries. [ZeroMQ Install guide](http://www.zeromq.org/intro:get-the-software) + * + * Usage: ZeroMQWordCount <master> <zeroMQurl> <topic> + * In local mode, <master> should be 'local[n]' with n > 1 + * <zeroMQurl> and <topic> describe where zeroMq publisher is running. + * + * To run this example locally, you may run publisher as + * `$ ./run spark.streaming.examples.SimpleZeroMQPublisher tcp://127.0.1.1:1234 foo.bar` + * and run the example as + * `$ ./run spark.streaming.examples.ZeroMQWordCount local[2] tcp://127.0.1.1:1234 foo` + */ +object ZeroMQWordCount { + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println( + "Usage: ZeroMQWordCount <master> <zeroMQurl> <topic>" + + "In local mode, <master> should be 'local[n]' with n > 1") + System.exit(1) + } + val Seq(master, url, topic) = args.toSeq + + // Create the context and set the batch size + val ssc = new StreamingContext(master, "ZeroMQWordCount", Seconds(2)) + + def bytesToStringIterator(x: Seq[Seq[Byte]]) = (x.map(x => new String(x.toArray))).iterator + + //For this stream, a zeroMQ publisher should be running. + val lines = ssc.zeroMQStream(url, Subscribe(topic), bytesToStringIterator) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } + +} \ No newline at end of file diff --git a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala index a191321d9105801d70c76c0a557323a897b1a1ae..fba72519a924426ec6dda0923e49f3e12f039844 100644 --- a/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/spark/streaming/examples/clickstream/PageViewStream.scala @@ -27,17 +27,16 @@ object PageViewStream { val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1)) // Create a NetworkInputDStream on target host:port and convert each line to a PageView - val pageViews = ssc.networkTextStream(host, port) - .flatMap(_.split("\n")) - .map(PageView.fromString(_)) + val pageViews = ssc.socketTextStream(host, port) + .flatMap(_.split("\n")) + .map(PageView.fromString(_)) // Return a count of views per URL seen in each batch - val pageCounts = pageViews.map(view => ((view.url, 1))).countByKey() + val pageCounts = pageViews.map(view => view.url).countByValue() // Return a sliding window of page views per URL in the last ten seconds - val slidingPageCounts = pageViews.map(view => ((view.url, 1))) - .window(Seconds(10), Seconds(2)) - .countByKey() + val slidingPageCounts = pageViews.map(view => view.url) + .countByValueAndWindow(Seconds(10), Seconds(2)) // Return the rate of error pages (a non 200 status) in each zip code over the last 30 seconds diff --git a/pom.xml b/pom.xml index 7e06cae052b58d6ebc1e5240f4810601287d9ca2..99eb17856a813a42c47cc5388637e196b6b5d4ae 100644 --- a/pom.xml +++ b/pom.xml @@ -84,9 +84,9 @@ </snapshots> </repository> <repository> - <id>typesafe-repo</id> - <name>Typesafe Repository</name> - <url>http://repo.typesafe.com/typesafe/releases/</url> + <id>akka-repo</id> + <name>Akka Repository</name> + <url>http://repo.akka.io/releases/</url> <releases> <enabled>true</enabled> </releases> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index af8b5ba01745b59f2c8c82f3c6c6bdb58a01fa45..25c232837358d6d9b68cd185d53febbf5a607f7b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -44,6 +44,9 @@ object SparkBuild extends Build { transitiveClassifiers in Scope.GlobalScope := Seq("sources"), testListeners <<= target.map(t => Seq(new eu.henkelmann.sbt.JUnitXmlTestsListener(t.getAbsolutePath))), + // shared between both core and streaming. + resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"), + // For Sonatype publishing resolvers ++= Seq("sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", "sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/"), @@ -114,7 +117,6 @@ object SparkBuild extends Build { def coreSettings = sharedSettings ++ Seq( name := "spark-core", resolvers ++= Seq( - "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/", "JBoss Repository" at "http://repository.jboss.org/nexus/content/repositories/releases/", "Spray Repository" at "http://repo.spray.cc/", "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/", @@ -155,9 +157,7 @@ object SparkBuild extends Build { def examplesSettings = sharedSettings ++ Seq( name := "spark-examples", - libraryDependencies ++= Seq( - "org.twitter4j" % "twitter4j-stream" % "3.0.3" - ) + libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.8") ) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") @@ -166,7 +166,9 @@ object SparkBuild extends Build { name := "spark-streaming", libraryDependencies ++= Seq( "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile", - "com.github.sgroschupf" % "zkclient" % "0.1" + "com.github.sgroschupf" % "zkclient" % "0.1", + "org.twitter4j" % "twitter4j-stream" % "3.0.3", + "com.typesafe.akka" % "akka-zeromq" % "2.0.3" ) ) ++ assemblySettings ++ extraAssemblySettings diff --git a/pyspark b/pyspark index ab7f4f50c01bd7f34b49a05e8bde24fa756e573d..d662e90287edfb3c45e95cf8f3544ed76ddc82ac 100755 --- a/pyspark +++ b/pyspark @@ -36,4 +36,9 @@ if [[ "$SPARK_LAUNCH_WITH_SCALA" != "0" ]] ; then export SPARK_LAUNCH_WITH_SCALA=1 fi -exec "$PYSPARK_PYTHON" "$@" +if [[ "$IPYTHON" = "1" ]] ; then + export PYSPARK_PYTHON="ipython" + exec "$PYSPARK_PYTHON" -i -c "%run $PYTHONSTARTUP" +else + exec "$PYSPARK_PYTHON" "$@" +fi diff --git a/python/pyspark/join.py b/python/pyspark/join.py index 7036c479806679d539a4779df92ed1bfd94bf694..5f4294fb1b7778090584041286822eabbb0e1efe 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -32,13 +32,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ -def _do_python_join(rdd, other, numSplits, dispatch): +def _do_python_join(rdd, other, numPartitions, dispatch): vs = rdd.map(lambda (k, v): (k, (1, v))) ws = other.map(lambda (k, v): (k, (2, v))) - return vs.union(ws).groupByKey(numSplits).flatMapValues(dispatch) + return vs.union(ws).groupByKey(numPartitions).flatMapValues(dispatch) -def python_join(rdd, other, numSplits): +def python_join(rdd, other, numPartitions): def dispatch(seq): vbuf, wbuf = [], [] for (n, v) in seq: @@ -47,10 +47,10 @@ def python_join(rdd, other, numSplits): elif n == 2: wbuf.append(v) return [(v, w) for v in vbuf for w in wbuf] - return _do_python_join(rdd, other, numSplits, dispatch) + return _do_python_join(rdd, other, numPartitions, dispatch) -def python_right_outer_join(rdd, other, numSplits): +def python_right_outer_join(rdd, other, numPartitions): def dispatch(seq): vbuf, wbuf = [], [] for (n, v) in seq: @@ -61,10 +61,10 @@ def python_right_outer_join(rdd, other, numSplits): if not vbuf: vbuf.append(None) return [(v, w) for v in vbuf for w in wbuf] - return _do_python_join(rdd, other, numSplits, dispatch) + return _do_python_join(rdd, other, numPartitions, dispatch) -def python_left_outer_join(rdd, other, numSplits): +def python_left_outer_join(rdd, other, numPartitions): def dispatch(seq): vbuf, wbuf = [], [] for (n, v) in seq: @@ -75,10 +75,10 @@ def python_left_outer_join(rdd, other, numSplits): if not wbuf: wbuf.append(None) return [(v, w) for v in vbuf for w in wbuf] - return _do_python_join(rdd, other, numSplits, dispatch) + return _do_python_join(rdd, other, numPartitions, dispatch) -def python_cogroup(rdd, other, numSplits): +def python_cogroup(rdd, other, numPartitions): vs = rdd.map(lambda (k, v): (k, (1, v))) ws = other.map(lambda (k, v): (k, (2, v))) def dispatch(seq): @@ -89,4 +89,4 @@ def python_cogroup(rdd, other, numSplits): elif n == 2: wbuf.append(v) return (vbuf, wbuf) - return vs.union(ws).groupByKey(numSplits).mapValues(dispatch) + return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 4cda6cf661197f662351bb5f08154b7aae1241f2..172ed85fab9267b7cfffbec8c0474701c9faf23f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -215,7 +215,7 @@ class RDD(object): yield pair return java_cartesian.flatMap(unpack_batches) - def groupBy(self, f, numSplits=None): + def groupBy(self, f, numPartitions=None): """ Return an RDD of grouped items. @@ -224,7 +224,7 @@ class RDD(object): >>> sorted([(x, sorted(y)) for (x, y) in result]) [(0, [2, 8]), (1, [1, 1, 3, 5])] """ - return self.map(lambda x: (f(x), x)).groupByKey(numSplits) + return self.map(lambda x: (f(x), x)).groupByKey(numPartitions) def pipe(self, command, env={}): """ @@ -274,8 +274,8 @@ class RDD(object): def reduce(self, f): """ - Reduces the elements of this RDD using the specified associative binary - operator. + Reduces the elements of this RDD using the specified commutative and + associative binary operator. >>> from operator import add >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add) @@ -422,22 +422,22 @@ class RDD(object): """ return dict(self.collect()) - def reduceByKey(self, func, numSplits=None): + def reduceByKey(self, func, numPartitions=None): """ Merge the values for each key using an associative reduce function. This will also perform the merging locally on each mapper before sending results to a reducer, similarly to a "combiner" in MapReduce. - Output will be hash-partitioned with C{numSplits} splits, or the - default parallelism level if C{numSplits} is not specified. + Output will be hash-partitioned with C{numPartitions} partitions, or + the default parallelism level if C{numPartitions} is not specified. >>> from operator import add >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(rdd.reduceByKey(add).collect()) [('a', 2), ('b', 1)] """ - return self.combineByKey(lambda x: x, func, func, numSplits) + return self.combineByKey(lambda x: x, func, func, numPartitions) def reduceByKeyLocally(self, func): """ @@ -474,7 +474,7 @@ class RDD(object): """ return self.map(lambda x: x[0]).countByValue() - def join(self, other, numSplits=None): + def join(self, other, numPartitions=None): """ Return an RDD containing all pairs of elements with matching keys in C{self} and C{other}. @@ -489,9 +489,9 @@ class RDD(object): >>> sorted(x.join(y).collect()) [('a', (1, 2)), ('a', (1, 3))] """ - return python_join(self, other, numSplits) + return python_join(self, other, numPartitions) - def leftOuterJoin(self, other, numSplits=None): + def leftOuterJoin(self, other, numPartitions=None): """ Perform a left outer join of C{self} and C{other}. @@ -506,9 +506,9 @@ class RDD(object): >>> sorted(x.leftOuterJoin(y).collect()) [('a', (1, 2)), ('b', (4, None))] """ - return python_left_outer_join(self, other, numSplits) + return python_left_outer_join(self, other, numPartitions) - def rightOuterJoin(self, other, numSplits=None): + def rightOuterJoin(self, other, numPartitions=None): """ Perform a right outer join of C{self} and C{other}. @@ -523,10 +523,10 @@ class RDD(object): >>> sorted(y.rightOuterJoin(x).collect()) [('a', (2, 1)), ('b', (None, 4))] """ - return python_right_outer_join(self, other, numSplits) + return python_right_outer_join(self, other, numPartitions) # TODO: add option to control map-side combining - def partitionBy(self, numSplits, partitionFunc=hash): + def partitionBy(self, numPartitions, partitionFunc=hash): """ Return a copy of the RDD partitioned using the specified partitioner. @@ -535,22 +535,22 @@ class RDD(object): >>> set(sets[0]).intersection(set(sets[1])) set([]) """ - if numSplits is None: - numSplits = self.ctx.defaultParallelism + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism # Transferring O(n) objects to Java is too expensive. Instead, we'll - # form the hash buckets in Python, transferring O(numSplits) objects + # form the hash buckets in Python, transferring O(numPartitions) objects # to Java. Each object is a (splitNumber, [objects]) pair. def add_shuffle_key(split, iterator): buckets = defaultdict(list) for (k, v) in iterator: - buckets[partitionFunc(k) % numSplits].append((k, v)) + buckets[partitionFunc(k) % numPartitions].append((k, v)) for (split, items) in buckets.iteritems(): yield str(split) yield dump_pickle(Batch(items)) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx._jvm.PythonPartitioner(numSplits, + partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, id(partitionFunc)) jrdd = pairRDD.partitionBy(partitioner).values() rdd = RDD(jrdd, self.ctx) @@ -561,7 +561,7 @@ class RDD(object): # TODO: add control over map-side aggregation def combineByKey(self, createCombiner, mergeValue, mergeCombiners, - numSplits=None): + numPartitions=None): """ Generic function to combine the elements for each key using a custom set of aggregation functions. @@ -586,8 +586,8 @@ class RDD(object): >>> sorted(x.combineByKey(str, add, add).collect()) [('a', '11'), ('b', '1')] """ - if numSplits is None: - numSplits = self.ctx.defaultParallelism + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism def combineLocally(iterator): combiners = {} for (k, v) in iterator: @@ -597,7 +597,7 @@ class RDD(object): combiners[k] = mergeValue(combiners[k], v) return combiners.iteritems() locally_combined = self.mapPartitions(combineLocally) - shuffled = locally_combined.partitionBy(numSplits) + shuffled = locally_combined.partitionBy(numPartitions) def _mergeCombiners(iterator): combiners = {} for (k, v) in iterator: @@ -609,10 +609,10 @@ class RDD(object): return shuffled.mapPartitions(_mergeCombiners) # TODO: support variant with custom partitioner - def groupByKey(self, numSplits=None): + def groupByKey(self, numPartitions=None): """ Group the values for each key in the RDD into a single sequence. - Hash-partitions the resulting RDD with into numSplits partitions. + Hash-partitions the resulting RDD with into numPartitions partitions. >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> sorted(x.groupByKey().collect()) @@ -630,7 +630,7 @@ class RDD(object): return a + b return self.combineByKey(createCombiner, mergeValue, mergeCombiners, - numSplits) + numPartitions) # TODO: add tests def flatMapValues(self, f): @@ -659,7 +659,7 @@ class RDD(object): return self.cogroup(other) # TODO: add variant with custom parittioner - def cogroup(self, other, numSplits=None): + def cogroup(self, other, numPartitions=None): """ For each key k in C{self} or C{other}, return a resulting RDD that contains a tuple with the list of values for that key in C{self} as well @@ -670,7 +670,7 @@ class RDD(object): >>> sorted(x.cogroup(y).collect()) [('a', ([1], [2])), ('b', ([4], []))] """ - return python_cogroup(self, other, numSplits) + return python_cogroup(self, other, numPartitions) # TODO: `lookup` is disabled because we can't make direct comparisons based # on the key; we need to compare the hash of the key to the hash of the diff --git a/run b/run index a0946294497a842adeef1f287377fb3f0b6c719a..ecbf7673c660f76c7294f76e9f28104b51453602 100755 --- a/run +++ b/run @@ -13,6 +13,38 @@ if [ -e $FWDIR/conf/spark-env.sh ] ; then . $FWDIR/conf/spark-env.sh fi +if [ -z "$1" ]; then + echo "Usage: run <spark-class> [<args>]" >&2 + exit 1 +fi + +# If this is a standalone cluster daemon, reset SPARK_JAVA_OPTS and SPARK_MEM to reasonable +# values for that; it doesn't need a lot +if [ "$1" = "spark.deploy.master.Master" -o "$1" = "spark.deploy.worker.Worker" ]; then + SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m} + SPARK_JAVA_OPTS=$SPARK_DAEMON_JAVA_OPTS # Empty by default +fi + + +# Add java opts for master, worker, executor. The opts maybe null +case "$1" in + 'spark.deploy.master.Master') + SPARK_JAVA_OPTS+=" $SPARK_MASTER_OPTS" + ;; + 'spark.deploy.worker.Worker') + SPARK_JAVA_OPTS+=" $SPARK_WORKER_OPTS" + ;; + 'spark.executor.StandaloneExecutorBackend') + SPARK_JAVA_OPTS+=" $SPARK_EXECUTOR_OPTS" + ;; + 'spark.executor.MesosExecutorBackend') + SPARK_JAVA_OPTS+=" $SPARK_EXECUTOR_OPTS" + ;; + 'spark.repl.Main') + SPARK_JAVA_OPTS+=" $SPARK_REPL_OPTS" + ;; +esac + if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then if [ `command -v scala` ]; then RUNNER="scala" @@ -79,11 +111,13 @@ CLASSPATH+=":$FWDIR/conf" CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/classes" if [ -n "$SPARK_TESTING" ] ; then CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes" fi CLASSPATH+=":$CORE_DIR/src/main/resources" CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes" CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes" CLASSPATH+=":$STREAMING_DIR/target/scala-$SCALA_VERSION/classes" +CLASSPATH+=":$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar if [ -e "$FWDIR/lib_managed" ]; then CLASSPATH+=":$FWDIR/lib_managed/jars/*" CLASSPATH+=":$FWDIR/lib_managed/bundles/*" diff --git a/run2.cmd b/run2.cmd index 67f1e465e47b37d0b1a63cc6d2a22e2993462147..705a4d1ff6830af502631da58eb1b0ee105c2274 100644 --- a/run2.cmd +++ b/run2.cmd @@ -11,9 +11,22 @@ set SPARK_HOME=%FWDIR% rem Load environment variables from conf\spark-env.cmd, if it exists if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" +rem Test that an argument was given +if not "x%1"=="x" goto arg_given + echo Usage: run ^<spark-class^> [^<args^>] + goto exit +:arg_given + +set RUNNING_DAEMON=0 +if "%1"=="spark.deploy.master.Master" set RUNNING_DAEMON=1 +if "%1"=="spark.deploy.worker.Worker" set RUNNING_DAEMON=1 +if "x%SPARK_DAEMON_MEMORY%" == "x" set SPARK_DAEMON_MEMORY=512m +if "%RUNNING_DAEMON%"=="1" set SPARK_MEM=%SPARK_DAEMON_MEMORY% +if "%RUNNING_DAEMON%"=="1" set SPARK_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% + rem Check that SCALA_HOME has been specified if not "x%SCALA_HOME%"=="x" goto scala_exists - echo "SCALA_HOME is not set" + echo SCALA_HOME is not set goto exit :scala_exists @@ -34,16 +47,19 @@ set CORE_DIR=%FWDIR%core set REPL_DIR=%FWDIR%repl set EXAMPLES_DIR=%FWDIR%examples set BAGEL_DIR=%FWDIR%bagel +set STREAMING_DIR=%FWDIR%streaming set PYSPARK_DIR=%FWDIR%python rem Build up classpath set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes set CLASSPATH=%CLASSPATH%;%CORE_DIR%\target\scala-%SCALA_VERSION%\test-classes;%CORE_DIR%\src\main\resources +set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\classes;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\test-classes +set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\lib\org\apache\kafka\kafka\0.7.2-spark\* set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\classes -for /R "%FWDIR%\lib_managed\jars" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j -for /R "%FWDIR%\lib_managed\bundles" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j -for /R "%REPL_DIR%\lib" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j -for /R "%PYSPARK_DIR%\lib" %%j in (*.jar) do set CLASSPATH=!CLASSPATH!;%%j +set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\jars\* +set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\bundles\* +set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\* +set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\* set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes rem Figure out whether to run our class with java or with the scala launcher. diff --git a/sbt/sbt.cmd b/sbt/sbt.cmd index 6b289ab44761c11846fb212984a5f022b9d52a3d..ce3ae7017458b1ee3703ad7cd0a928ebda2756af 100644 --- a/sbt/sbt.cmd +++ b/sbt/sbt.cmd @@ -2,4 +2,4 @@ set EXTRA_ARGS= if not "%MESOS_HOME%x"=="x" set EXTRA_ARGS=-Djava.library.path=%MESOS_HOME%\lib\java set SPARK_HOME=%~dp0.. -java -Xmx1200M -XX:MaxPermSize=200m %EXTRA_ARGS% -jar %SPARK_HOME%\sbt\sbt-launch-*.jar "%*" +java -Xmx1200M -XX:MaxPermSize=200m %EXTRA_ARGS% -jar %SPARK_HOME%\sbt\sbt-launch-0.11.3-2.jar "%*" diff --git a/streaming/pom.xml b/streaming/pom.xml index 6ee7e59df39d16b1bf47ad23604bfe7f80f40f53..15523eadcb78198d84a5a5d5e21518e514c3dc78 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -47,7 +47,16 @@ <artifactId>zkclient</artifactId> <version>0.1</version> </dependency> - + <dependency> + <groupId>org.twitter4j</groupId> + <artifactId>twitter4j-stream</artifactId> + <version>3.0.3</version> + </dependency> + <dependency> + <groupId>com.typesafe.akka</groupId> + <artifactId>akka-zeromq</artifactId> + <version>2.0.3</version> + </dependency> <dependency> <groupId>org.scalatest</groupId> <artifactId>scalatest_${scala.version}</artifactId> diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 2f3adb39c2506590c9d65b80285eb6054deba111..e7a392fbbf346fe2c55b49acc7ebeb5413aa520c 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -6,18 +6,21 @@ import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration import java.io._ +import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} +import java.util.concurrent.Executors private[streaming] class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Logging with Serializable { val master = ssc.sc.master - val framework = ssc.sc.jobName + val framework = ssc.sc.appName val sparkHome = ssc.sc.sparkHome val jars = ssc.sc.jars val graph = ssc.graph val checkpointDir = ssc.checkpointDir - val checkpointDuration: Duration = ssc.checkpointDuration + val checkpointDuration = ssc.checkpointDuration + val pendingTimes = ssc.scheduler.jobManager.getPendingTimes() def validate() { assert(master != null, "Checkpoint.master is null") @@ -37,32 +40,50 @@ class CheckpointWriter(checkpointDir: String) extends Logging { val conf = new Configuration() var fs = file.getFileSystem(conf) val maxAttempts = 3 + val executor = Executors.newFixedThreadPool(1) - def write(checkpoint: Checkpoint) { - // TODO: maybe do this in a different thread from the main stream execution thread - var attempts = 0 - while (attempts < maxAttempts) { - attempts += 1 - try { - logDebug("Saving checkpoint for time " + checkpoint.checkpointTime + " to file '" + file + "'") - if (fs.exists(file)) { - val bkFile = new Path(file.getParent, file.getName + ".bk") - FileUtil.copy(fs, file, fs, bkFile, true, true, conf) - logDebug("Moved existing checkpoint file to " + bkFile) + class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable { + def run() { + var attempts = 0 + val startTime = System.currentTimeMillis() + while (attempts < maxAttempts) { + attempts += 1 + try { + logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'") + if (fs.exists(file)) { + val bkFile = new Path(file.getParent, file.getName + ".bk") + FileUtil.copy(fs, file, fs, bkFile, true, true, conf) + logDebug("Moved existing checkpoint file to " + bkFile) + } + val fos = fs.create(file) + fos.write(bytes) + fos.close() + fos.close() + val finishTime = System.currentTimeMillis(); + logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file + + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds") + return + } catch { + case ioe: IOException => + logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe) } - val fos = fs.create(file) - val oos = new ObjectOutputStream(fos) - oos.writeObject(checkpoint) - oos.close() - logInfo("Checkpoint for time " + checkpoint.checkpointTime + " saved to file '" + file + "'") - fos.close() - return - } catch { - case ioe: IOException => - logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe) } + logError("Could not write checkpoint for time " + checkpointTime + " to file '" + file + "'") } - logError("Could not write checkpoint for time " + checkpoint.checkpointTime + " to file '" + file + "'") + } + + def write(checkpoint: Checkpoint) { + val bos = new ByteArrayOutputStream() + val zos = new LZFOutputStream(bos) + val oos = new ObjectOutputStream(zos) + oos.writeObject(checkpoint) + oos.close() + bos.close() + executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray)) + } + + def stop() { + executor.shutdown() } } @@ -84,7 +105,8 @@ object CheckpointReader extends Logging { // of ObjectInputStream is used to explicitly use the current thread's default class // loader to find and load classes. This is a well know Java issue and has popped up // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) + val zis = new LZFInputStream(fis) + val ois = new ObjectInputStreamWithLoader(zis, Thread.currentThread().getContextClassLoader) val cp = ois.readObject.asInstanceOf[Checkpoint] ois.close() fs.close() diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 352f83fe0c33bdfee1093efd90aec6206ad825c4..e1be5ef51cc9cd8cc4c71c52b6f8dcb820dbb761 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -12,7 +12,7 @@ import scala.collection.mutable.HashMap import java.io.{ObjectInputStream, IOException, ObjectOutputStream} -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration /** @@ -75,7 +75,7 @@ abstract class DStream[T: ClassManifest] ( // Checkpoint details protected[streaming] val mustCheckpoint = false protected[streaming] var checkpointDuration: Duration = null - protected[streaming] var checkpointData = new DStreamCheckpointData(HashMap[Time, Any]()) + protected[streaming] val checkpointData = new DStreamCheckpointData(this) // Reference to whole DStream graph protected[streaming] var graph: DStreamGraph = null @@ -85,10 +85,10 @@ abstract class DStream[T: ClassManifest] ( // Duration for which the DStream requires its parent DStream to remember each RDD created protected[streaming] def parentRememberDuration = rememberDuration - /** Returns the StreamingContext associated with this DStream */ - def context() = ssc + /** Return the StreamingContext associated with this DStream */ + def context = ssc - /** Persists the RDDs of this DStream with the given storage level */ + /** Persist the RDDs of this DStream with the given storage level */ def persist(level: StorageLevel): DStream[T] = { if (this.isInitialized) { throw new UnsupportedOperationException( @@ -132,7 +132,7 @@ abstract class DStream[T: ClassManifest] ( // Set the checkpoint interval to be slideDuration or 10 seconds, which ever is larger if (mustCheckpoint && checkpointDuration == null) { - checkpointDuration = slideDuration.max(Seconds(10)) + checkpointDuration = slideDuration * math.ceil(Seconds(10) / slideDuration).toInt logInfo("Checkpoint interval automatically set to " + checkpointDuration) } @@ -159,7 +159,7 @@ abstract class DStream[T: ClassManifest] ( ) assert( - checkpointDuration == null || ssc.sc.checkpointDir.isDefined, + checkpointDuration == null || context.sparkContext.checkpointDir.isDefined, "The checkpoint directory has not been set. Please use StreamingContext.checkpoint()" + " or SparkContext.checkpoint() to set the checkpoint directory." ) @@ -238,13 +238,15 @@ abstract class DStream[T: ClassManifest] ( dependencies.foreach(_.remember(parentRememberDuration)) } - /** This method checks whether the 'time' is valid wrt slideDuration for generating RDD */ + /** Checks whether the 'time' is valid wrt slideDuration for generating RDD */ protected def isTimeValid(time: Time): Boolean = { if (!isInitialized) { throw new Exception (this + " has not been initialized") } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) { + logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime)) false } else { + logInfo("Time " + time + " is valid") true } } @@ -292,14 +294,14 @@ abstract class DStream[T: ClassManifest] ( * Generate a SparkStreaming job for the given time. This is an internal method that * should not be called directly. This default implementation creates a job * that materializes the corresponding RDD. Subclasses of DStream may override this - * (eg. ForEachDStream). + * to generate their own jobs. */ protected[streaming] def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { case Some(rdd) => { val jobFunc = () => { - val emptyFunc = { (iterator: Iterator[T]) => {} } - ssc.sc.runJob(rdd, emptyFunc) + val emptyFunc = { (iterator: Iterator[T]) => {} } + context.sparkContext.runJob(rdd, emptyFunc) } Some(new Job(time, jobFunc)) } @@ -308,20 +310,18 @@ abstract class DStream[T: ClassManifest] ( } /** - * Dereference RDDs that are older than rememberDuration. + * Clear metadata that are older than `rememberDuration` of this DStream. + * This is an internal method that should not be called directly. This default + * implementation clears the old generated RDDs. Subclasses of DStream may override + * this to clear their own metadata along with the generated RDDs. */ - protected[streaming] def forgetOldRDDs(time: Time) { - val keys = generatedRDDs.keys + protected[streaming] def clearOldMetadata(time: Time) { var numForgotten = 0 - keys.foreach(t => { - if (t <= (time - rememberDuration)) { - generatedRDDs.remove(t) - numForgotten += 1 - logInfo("Forgot RDD of time " + t + " from " + this) - } - }) - logInfo("Forgot " + numForgotten + " RDDs from " + this) - dependencies.foreach(_.forgetOldRDDs(time)) + val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration)) + generatedRDDs --= oldRDDs.keys + logInfo("Cleared " + oldRDDs.size + " RDDs that were older than " + + (time - rememberDuration) + ": " + oldRDDs.keys.mkString(", ")) + dependencies.foreach(_.clearOldMetadata(time)) } /* Adds metadata to the Stream while it is running. @@ -342,40 +342,10 @@ abstract class DStream[T: ClassManifest] ( */ protected[streaming] def updateCheckpointData(currentTime: Time) { logInfo("Updating checkpoint data for time " + currentTime) - - // Get the checkpointed RDDs from the generated RDDs - val newRdds = generatedRDDs.filter(_._2.getCheckpointFile.isDefined) - .map(x => (x._1, x._2.getCheckpointFile.get)) - - // Make a copy of the existing checkpoint data (checkpointed RDDs) - val oldRdds = checkpointData.rdds.clone() - - // If the new checkpoint data has checkpoints then replace existing with the new one - if (newRdds.size > 0) { - checkpointData.rdds.clear() - checkpointData.rdds ++= newRdds - } - - // Make parent DStreams update their checkpoint data + checkpointData.update() dependencies.foreach(_.updateCheckpointData(currentTime)) - - // TODO: remove this, this is just for debugging - newRdds.foreach { - case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } - } - - if (newRdds.size > 0) { - (oldRdds -- newRdds.keySet).foreach { - case (time, data) => { - val path = new Path(data.toString) - val fs = path.getFileSystem(new Configuration()) - fs.delete(path, true) - logInfo("Deleted checkpoint file '" + path + "' for time " + time) - } - } - } - logInfo("Updated checkpoint data for time " + currentTime + ", " + checkpointData.rdds.size + " checkpoints, " - + "[" + checkpointData.rdds.mkString(",") + "]") + checkpointData.cleanup() + logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData) } /** @@ -386,14 +356,8 @@ abstract class DStream[T: ClassManifest] ( */ protected[streaming] def restoreCheckpointData() { // Create RDDs from the checkpoint data - logInfo("Restoring checkpoint data from " + checkpointData.rdds.size + " checkpointed RDDs") - checkpointData.rdds.foreach { - case(time, data) => { - logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'") - val rdd = ssc.sc.checkpointFile[T](data.toString) - generatedRDDs += ((time, rdd)) - } - } + logInfo("Restoring checkpoint data") + checkpointData.restore() dependencies.foreach(_.restoreCheckpointData()) logInfo("Restored checkpoint data") } @@ -433,7 +397,7 @@ abstract class DStream[T: ClassManifest] ( /** Return a new DStream by applying a function to all elements of this DStream. */ def map[U: ClassManifest](mapFunc: T => U): DStream[U] = { - new MappedDStream(this, ssc.sc.clean(mapFunc)) + new MappedDStream(this, context.sparkContext.clean(mapFunc)) } /** @@ -441,7 +405,7 @@ abstract class DStream[T: ClassManifest] ( * and then flattening the results */ def flatMap[U: ClassManifest](flatMapFunc: T => Traversable[U]): DStream[U] = { - new FlatMappedDStream(this, ssc.sc.clean(flatMapFunc)) + new FlatMappedDStream(this, context.sparkContext.clean(flatMapFunc)) } /** Return a new DStream containing only the elements that satisfy a predicate. */ @@ -463,7 +427,7 @@ abstract class DStream[T: ClassManifest] ( mapPartFunc: Iterator[T] => Iterator[U], preservePartitioning: Boolean = false ): DStream[U] = { - new MapPartitionedDStream(this, ssc.sc.clean(mapPartFunc), preservePartitioning) + new MapPartitionedDStream(this, context.sparkContext.clean(mapPartFunc), preservePartitioning) } /** @@ -479,6 +443,15 @@ abstract class DStream[T: ClassManifest] ( */ def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _) + /** + * Return a new DStream in which each RDD contains the counts of each distinct value in + * each RDD of this DStream. Hash partitioning is used to generate + * the RDDs with `numPartitions` partitions (Spark's default number of partitions if + * `numPartitions` not specified). + */ + def countByValue(numPartitions: Int = ssc.sc.defaultParallelism): DStream[(T, Long)] = + this.map(x => (x, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) + /** * Apply a function to each RDD in this DStream. This is an output operator, so * this DStream will be registered as an output stream and therefore materialized. @@ -492,7 +465,7 @@ abstract class DStream[T: ClassManifest] ( * this DStream will be registered as an output stream and therefore materialized. */ def foreach(foreachFunc: (RDD[T], Time) => Unit) { - val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc)) + val newStream = new ForEachDStream(this, context.sparkContext.clean(foreachFunc)) ssc.registerOutputStream(newStream) newStream } @@ -510,7 +483,7 @@ abstract class DStream[T: ClassManifest] ( * on each RDD of this DStream. */ def transform[U: ClassManifest](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = { - new TransformedDStream(this, ssc.sc.clean(transformFunc)) + new TransformedDStream(this, context.sparkContext.clean(transformFunc)) } /** @@ -527,19 +500,21 @@ abstract class DStream[T: ClassManifest] ( if (first11.size > 10) println("...") println() } - val newStream = new ForEachDStream(this, ssc.sc.clean(foreachFunc)) + val newStream = new ForEachDStream(this, context.sparkContext.clean(foreachFunc)) ssc.registerOutputStream(newStream) } /** - * Return a new DStream which is computed based on windowed batches of this DStream. - * The new DStream generates RDDs with the same interval as this DStream. + * Return a new DStream in which each RDD contains all the elements in seen in a + * sliding window of time over this DStream. The new DStream generates RDDs with + * the same interval as this DStream. * @param windowDuration width of the window; must be a multiple of this DStream's interval. */ def window(windowDuration: Duration): DStream[T] = window(windowDuration, this.slideDuration) /** - * Return a new DStream which is computed based on windowed batches of this DStream. + * Return a new DStream in which each RDD contains all the elements in seen in a + * sliding window of time over this DStream. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which @@ -550,28 +525,40 @@ abstract class DStream[T: ClassManifest] ( new WindowedDStream(this, windowDuration, slideDuration) } - /** - * Return a new DStream which computed based on tumbling window on this DStream. - * This is equivalent to window(batchTime, batchTime). - * @param batchDuration tumbling window duration; must be a multiple of this DStream's - * batching interval - */ - def tumble(batchDuration: Duration): DStream[T] = window(batchDuration, batchDuration) - /** * Return a new DStream in which each RDD has a single element generated by reducing all - * elements in a window over this DStream. windowDuration and slideDuration are as defined - * in the window() operation. This is equivalent to - * window(windowDuration, slideDuration).reduce(reduceFunc) + * elements in a sliding window over this DStream. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval */ def reduceByWindow( reduceFunc: (T, T) => T, windowDuration: Duration, slideDuration: Duration ): DStream[T] = { - this.window(windowDuration, slideDuration).reduce(reduceFunc) + this.reduce(reduceFunc).window(windowDuration, slideDuration).reduce(reduceFunc) } + /** + * Return a new DStream in which each RDD has a single element generated by reducing all + * elements in a sliding window over this DStream. However, the reduction is done incrementally + * using the old window's reduced value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient than reduceByWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ def reduceByWindow( reduceFunc: (T, T) => T, invReduceFunc: (T, T) => T, @@ -585,13 +572,46 @@ abstract class DStream[T: ClassManifest] ( /** * Return a new DStream in which each RDD has a single element generated by counting the number - * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the - * window() operation. This is equivalent to window(windowDuration, slideDuration).count() + * of elements in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with + * Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval */ def countByWindow(windowDuration: Duration, slideDuration: Duration): DStream[Long] = { this.map(_ => 1L).reduceByWindow(_ + _, _ - _, windowDuration, slideDuration) } + /** + * Return a new DStream in which each RDD contains the count of distinct elements in + * RDDs in a sliding window over this DStream. Hash partitioning is used to generate + * the RDDs with `numPartitions` partitions (Spark's default number of partitions if + * `numPartitions` not specified). + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions number of partitions of each RDD in the new DStream. + */ + def countByValueAndWindow( + windowDuration: Duration, + slideDuration: Duration, + numPartitions: Int = ssc.sc.defaultParallelism + ): DStream[(T, Long)] = { + + this.map(x => (x, 1L)).reduceByKeyAndWindow( + (x: Long, y: Long) => x + y, + (x: Long, y: Long) => x - y, + windowDuration, + slideDuration, + numPartitions, + (x: (T, Long)) => x._2 != 0L + ) + } + /** * Return a new DStream by unifying data of another DStream with this DStream. * @param that Another DStream having the same slideDuration as this DStream. @@ -609,16 +629,21 @@ abstract class DStream[T: ClassManifest] ( * Return all the RDDs between 'fromTime' to 'toTime' (both included) */ def slice(fromTime: Time, toTime: Time): Seq[RDD[T]] = { - val rdds = new ArrayBuffer[RDD[T]]() - var time = toTime.floor(slideDuration) - while (time >= zeroTime && time >= fromTime) { - getOrCompute(time) match { - case Some(rdd) => rdds += rdd - case None => //throw new Exception("Could not get RDD for time " + time) - } - time -= slideDuration + if (!(fromTime - zeroTime).isMultipleOf(slideDuration)) { + logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration (" + slideDuration + ")") + } + if (!(toTime - zeroTime).isMultipleOf(slideDuration)) { + logWarning("toTime (" + fromTime + ") is not a multiple of slideDuration (" + slideDuration + ")") } - rdds.toSeq + val alignedToTime = toTime.floor(slideDuration) + val alignedFromTime = fromTime.floor(slideDuration) + + logInfo("Slicing from " + fromTime + " to " + toTime + + " (aligned to " + alignedFromTime + " and " + alignedToTime + ")") + + alignedFromTime.to(alignedToTime, slideDuration).flatMap(time => { + if (time >= zeroTime) getOrCompute(time) else None + }) } /** @@ -651,7 +676,3 @@ abstract class DStream[T: ClassManifest] ( ssc.registerOutputStream(this) } } - -private[streaming] -case class DStreamCheckpointData(rdds: HashMap[Time, Any]) - diff --git a/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala new file mode 100644 index 0000000000000000000000000000000000000000..6b0fade7c64d79ca6e6ddda5dff73ca07851e1be --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DStreamCheckpointData.scala @@ -0,0 +1,93 @@ +package spark.streaming + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.conf.Configuration +import collection.mutable.HashMap +import spark.Logging + + + +private[streaming] +class DStreamCheckpointData[T: ClassManifest] (dstream: DStream[T]) + extends Serializable with Logging { + protected val data = new HashMap[Time, AnyRef]() + + @transient private var fileSystem : FileSystem = null + @transient private var lastCheckpointFiles: HashMap[Time, String] = null + + protected[streaming] def checkpointFiles = data.asInstanceOf[HashMap[Time, String]] + + /** + * Updates the checkpoint data of the DStream. This gets called every time + * the graph checkpoint is initiated. Default implementation records the + * checkpoint files to which the generate RDDs of the DStream has been saved. + */ + def update() { + + // Get the checkpointed RDDs from the generated RDDs + val newCheckpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined) + .map(x => (x._1, x._2.getCheckpointFile.get)) + + // Make a copy of the existing checkpoint data (checkpointed RDDs) + lastCheckpointFiles = checkpointFiles.clone() + + // If the new checkpoint data has checkpoints then replace existing with the new one + if (newCheckpointFiles.size > 0) { + checkpointFiles.clear() + checkpointFiles ++= newCheckpointFiles + } + + // TODO: remove this, this is just for debugging + newCheckpointFiles.foreach { + case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } + } + } + + /** + * Cleanup old checkpoint data. This gets called every time the graph + * checkpoint is initiated, but after `update` is called. Default + * implementation, cleans up old checkpoint files. + */ + def cleanup() { + // If there is at least on checkpoint file in the current checkpoint files, + // then delete the old checkpoint files. + if (checkpointFiles.size > 0 && lastCheckpointFiles != null) { + (lastCheckpointFiles -- checkpointFiles.keySet).foreach { + case (time, file) => { + try { + val path = new Path(file) + if (fileSystem == null) { + fileSystem = path.getFileSystem(new Configuration()) + } + fileSystem.delete(path, true) + logInfo("Deleted checkpoint file '" + file + "' for time " + time) + } catch { + case e: Exception => + logWarning("Error deleting old checkpoint file '" + file + "' for time " + time, e) + } + } + } + } + } + + /** + * Restore the checkpoint data. This gets called once when the DStream graph + * (along with its DStreams) are being restored from a graph checkpoint file. + * Default implementation restores the RDDs from their checkpoint files. + */ + def restore() { + // Create RDDs from the checkpoint data + checkpointFiles.foreach { + case(time, file) => { + logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'") + dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file))) + } + } + } + + override def toString() = { + "[\n" + checkpointFiles.size + " checkpoint files \n" + checkpointFiles.mkString("\n") + "\n]" + } +} + diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index bc4a40d7bccd4cd1b277e48119bb9a677a9d999d..adb7f3a24d25f6fcbd1453c2e75b56b9a22d10b4 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -11,17 +11,20 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { private val inputStreams = new ArrayBuffer[InputDStream[_]]() private val outputStreams = new ArrayBuffer[DStream[_]]() - private[streaming] var zeroTime: Time = null - private[streaming] var batchDuration: Duration = null - private[streaming] var rememberDuration: Duration = null - private[streaming] var checkpointInProgress = false + var rememberDuration: Duration = null + var checkpointInProgress = false - private[streaming] def start(time: Time) { + var zeroTime: Time = null + var startTime: Time = null + var batchDuration: Duration = null + + def start(time: Time) { this.synchronized { if (zeroTime != null) { throw new Exception("DStream graph computation already started") } zeroTime = time + startTime = time outputStreams.foreach(_.initialize(zeroTime)) outputStreams.foreach(_.remember(rememberDuration)) outputStreams.foreach(_.validate) @@ -29,19 +32,23 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { } } - private[streaming] def stop() { + def restart(time: Time) { + this.synchronized { startTime = time } + } + + def stop() { this.synchronized { inputStreams.par.foreach(_.stop()) } } - private[streaming] def setContext(ssc: StreamingContext) { + def setContext(ssc: StreamingContext) { this.synchronized { outputStreams.foreach(_.setContext(ssc)) } } - private[streaming] def setBatchDuration(duration: Duration) { + def setBatchDuration(duration: Duration) { this.synchronized { if (batchDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + @@ -51,59 +58,68 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { batchDuration = duration } - private[streaming] def remember(duration: Duration) { + def remember(duration: Duration) { this.synchronized { if (rememberDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + ". cannot set it again.") } + rememberDuration = duration } - rememberDuration = duration } - private[streaming] def addInputStream(inputStream: InputDStream[_]) { + def addInputStream(inputStream: InputDStream[_]) { this.synchronized { inputStream.setGraph(this) inputStreams += inputStream } } - private[streaming] def addOutputStream(outputStream: DStream[_]) { + def addOutputStream(outputStream: DStream[_]) { this.synchronized { outputStream.setGraph(this) outputStreams += outputStream } } - private[streaming] def getInputStreams() = this.synchronized { inputStreams.toArray } + def getInputStreams() = this.synchronized { inputStreams.toArray } - private[streaming] def getOutputStreams() = this.synchronized { outputStreams.toArray } + def getOutputStreams() = this.synchronized { outputStreams.toArray } - private[streaming] def generateRDDs(time: Time): Seq[Job] = { + def generateJobs(time: Time): Seq[Job] = { this.synchronized { - outputStreams.flatMap(outputStream => outputStream.generateJob(time)) + logInfo("Generating jobs for time " + time) + val jobs = outputStreams.flatMap(outputStream => outputStream.generateJob(time)) + logInfo("Generated " + jobs.length + " jobs for time " + time) + jobs } } - private[streaming] def forgetOldRDDs(time: Time) { + def clearOldMetadata(time: Time) { this.synchronized { - outputStreams.foreach(_.forgetOldRDDs(time)) + logInfo("Clearing old metadata for time " + time) + outputStreams.foreach(_.clearOldMetadata(time)) + logInfo("Cleared old metadata for time " + time) } } - private[streaming] def updateCheckpointData(time: Time) { + def updateCheckpointData(time: Time) { this.synchronized { + logInfo("Updating checkpoint data for time " + time) outputStreams.foreach(_.updateCheckpointData(time)) + logInfo("Updated checkpoint data for time " + time) } } - private[streaming] def restoreCheckpointData() { + def restoreCheckpointData() { this.synchronized { + logInfo("Restoring checkpoint data") outputStreams.foreach(_.restoreCheckpointData()) + logInfo("Restored checkpoint data") } } - private[streaming] def validate() { + def validate() { this.synchronized { assert(batchDuration != null, "Batch duration has not been set") //assert(batchDuration >= Milliseconds(100), "Batch duration of " + batchDuration + " is very low") diff --git a/streaming/src/main/scala/spark/streaming/Duration.scala b/streaming/src/main/scala/spark/streaming/Duration.scala index e4dc579a170a4214ffa3a9ee959b160471a26dbe..ee26206e249a8abf6bf9f502b7dac6eaef65e0e6 100644 --- a/streaming/src/main/scala/spark/streaming/Duration.scala +++ b/streaming/src/main/scala/spark/streaming/Duration.scala @@ -16,7 +16,7 @@ case class Duration (private val millis: Long) { def * (times: Int): Duration = new Duration(millis * times) - def / (that: Duration): Long = millis / that.millis + def / (that: Duration): Double = millis.toDouble / that.millis.toDouble def isMultipleOf(that: Duration): Boolean = (this.millis % that.millis == 0) diff --git a/streaming/src/main/scala/spark/streaming/Interval.scala b/streaming/src/main/scala/spark/streaming/Interval.scala index dc21dfb72228801bb1db803901ec81e6b76ae411..6a8b81760e35b5fe81dc169155d7678c107bba9f 100644 --- a/streaming/src/main/scala/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/spark/streaming/Interval.scala @@ -30,6 +30,7 @@ class Interval(val beginTime: Time, val endTime: Time) { override def toString = "[" + beginTime + ", " + endTime + "]" } +private[streaming] object Interval { def currentInterval(duration: Duration): Interval = { val time = new Time(System.currentTimeMillis) diff --git a/streaming/src/main/scala/spark/streaming/JobManager.scala b/streaming/src/main/scala/spark/streaming/JobManager.scala index 3b910538e028925bd9125db6ab4e2045717a2c9e..7696c4a592bf6c3c9f5885da1e779f39f7ff3101 100644 --- a/streaming/src/main/scala/spark/streaming/JobManager.scala +++ b/streaming/src/main/scala/spark/streaming/JobManager.scala @@ -3,6 +3,8 @@ package spark.streaming import spark.Logging import spark.SparkEnv import java.util.concurrent.Executors +import collection.mutable.HashMap +import collection.mutable.ArrayBuffer private[streaming] @@ -13,21 +15,57 @@ class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging { SparkEnv.set(ssc.env) try { val timeTaken = job.run() - logInfo("Total delay: %.5f s for job %s (execution: %.5f s)".format( - (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, timeTaken / 1000.0)) + logInfo("Total delay: %.5f s for job %s of time %s (execution: %.5f s)".format( + (System.currentTimeMillis() - job.time.milliseconds) / 1000.0, job.id, job.time.milliseconds, timeTaken / 1000.0)) } catch { case e: Exception => logError("Running " + job + " failed", e) } + clearJob(job) } } initLogging() val jobExecutor = Executors.newFixedThreadPool(numThreads) - + val jobs = new HashMap[Time, ArrayBuffer[Job]] + def runJob(job: Job) { + jobs.synchronized { + jobs.getOrElseUpdate(job.time, new ArrayBuffer[Job]) += job + } jobExecutor.execute(new JobHandler(ssc, job)) logInfo("Added " + job + " to queue") } + + def stop() { + jobExecutor.shutdown() + } + + private def clearJob(job: Job) { + var timeCleared = false + val time = job.time + jobs.synchronized { + val jobsOfTime = jobs.get(time) + if (jobsOfTime.isDefined) { + jobsOfTime.get -= job + if (jobsOfTime.get.isEmpty) { + jobs -= time + timeCleared = true + } + } else { + throw new Exception("Job finished for time " + job.time + + " but time does not exist in jobs") + } + } + if (timeCleared) { + ssc.scheduler.clearOldMetadata(time) + } + } + + def getPendingTimes(): Array[Time] = { + jobs.synchronized { + jobs.keySet.toArray + } + } } diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index e4152f3a61bb5e3240b38fd706a6e784686152cd..b159d26c02b2d4383253110d1d589fc2a6b4261b 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -4,6 +4,7 @@ import spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} import spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} import spark.Logging import spark.SparkEnv +import spark.SparkContext._ import scala.collection.mutable.HashMap import scala.collection.mutable.Queue @@ -23,7 +24,7 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) ext */ private[streaming] class NetworkInputTracker( - @transient ssc: StreamingContext, + @transient ssc: StreamingContext, @transient networkInputStreams: Array[NetworkInputDStream[_]]) extends Logging { @@ -65,12 +66,12 @@ class NetworkInputTracker( def receive = { case RegisterReceiver(streamId, receiverActor) => { if (!networkInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected id " + streamId) + throw new Exception("Register received for unexpected id " + streamId) } receiverInfo += ((streamId, receiverActor)) logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address) sender ! true - } + } case AddBlocks(streamId, blockIds, metadata) => { val tmp = receivedBlockIds.synchronized { if (!receivedBlockIds.contains(streamId)) { @@ -85,7 +86,7 @@ class NetworkInputTracker( } case DeregisterReceiver(streamId, msg) => { receiverInfo -= streamId - logInfo("De-registered receiver for network stream " + streamId + logError("De-registered receiver for network stream " + streamId + " with message " + msg) //TODO: Do something about the corresponding NetworkInputDStream } @@ -95,8 +96,8 @@ class NetworkInputTracker( /** This thread class runs all the receivers on the cluster. */ class ReceiverExecutor extends Thread { val env = ssc.env - - override def run() { + + override def run() { try { SparkEnv.set(env) startReceivers() @@ -113,7 +114,7 @@ class NetworkInputTracker( */ def startReceivers() { val receivers = networkInputStreams.map(nis => { - val rcvr = nis.createReceiver() + val rcvr = nis.getReceiver() rcvr.setStreamId(nis.id) rcvr }) @@ -138,10 +139,14 @@ class NetworkInputTracker( } iterator.next().start() } + // Run the dummy Spark job to ensure that all slaves have registered. + // This avoids all the receivers to be scheduled on the same node. + ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() + // Distribute the receivers and start them - ssc.sc.runJob(tempRDD, startReceiver) + ssc.sparkContext.runJob(tempRDD, startReceiver) } - + /** Stops the receivers. */ def stopReceivers() { // Signal the receivers to stop diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index fbcf06112676d78c1ebc04a982a09145c379afdd..3ec922957d635135f21f279005405ec556c9d898 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -18,15 +18,15 @@ import org.apache.hadoop.conf.Configuration class PairDStreamFunctions[K: ClassManifest, V: ClassManifest](self: DStream[(K,V)]) extends Serializable { - - def ssc = self.ssc + + private[streaming] def ssc = self.ssc private[streaming] def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = { new HashPartitioner(numPartitions) } /** - * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. */ def groupByKey(): DStream[(K, Seq[V])] = { @@ -34,7 +34,7 @@ extends Serializable { } /** - * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. */ def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = { @@ -42,7 +42,7 @@ extends Serializable { } /** - * Create a new DStream by applying `groupByKey` on each RDD. The supplied [[spark.Partitioner]] + * Return a new DStream by applying `groupByKey` on each RDD. The supplied [[spark.Partitioner]] * is used to control the partitioning of each RDD. */ def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = { @@ -54,7 +54,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the associative reduce function. Hash partitioning is used to generate the RDDs * with Spark's default number of partitions. */ @@ -63,7 +63,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs * with `numPartitions` partitions. */ @@ -72,7 +72,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the supplied reduce function. [[spark.Partitioner]] is used to control the * partitioning of each RDD. */ @@ -82,7 +82,7 @@ extends Serializable { } /** - * Combine elements of each key in DStream's RDDs using custom function. This is similar to the + * Combine elements of each key in DStream's RDDs using custom functions. This is similar to the * combineByKey for RDDs. Please refer to combineByKey in [[spark.PairRDDFunctions]] for more * information. */ @@ -95,15 +95,7 @@ extends Serializable { } /** - * Create a new DStream by counting the number of values of each key in each RDD. Hash - * partitioning is used to generate the RDDs with Spark's `numPartitions` partitions. - */ - def countByKey(numPartitions: Int = self.ssc.sc.defaultParallelism): DStream[(K, Long)] = { - self.map(x => (x._1, 1L)).reduceByKey((x: Long, y: Long) => x + y, numPartitions) - } - - /** - * Creates a new DStream by applying `groupByKey` over a sliding window. This is similar to + * Return a new DStream by applying `groupByKey` over a sliding window. This is similar to * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs * with the same interval as this DStream. Hash partitioning is used to generate the RDDs with * Spark's default number of partitions. @@ -115,7 +107,7 @@ extends Serializable { } /** - * Create a new DStream by applying `groupByKey` over a sliding window. Similar to + * Return a new DStream by applying `groupByKey` over a sliding window. Similar to * `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. * @param windowDuration width of the window; must be a multiple of this DStream's @@ -129,7 +121,7 @@ extends Serializable { } /** - * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream. * Similar to `DStream.groupByKey()`, but applies it over a sliding window. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param windowDuration width of the window; must be a multiple of this DStream's @@ -137,7 +129,8 @@ extends Serializable { * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. + * @param numPartitions number of partitions of each RDD in the new DStream; if not specified + * then Spark's default number of partitions will be used */ def groupByKeyAndWindow( windowDuration: Duration, @@ -155,7 +148,7 @@ extends Serializable { * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval - * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @param partitioner partitioner for controlling the partitioning of each RDD in the new DStream. */ def groupByKeyAndWindow( windowDuration: Duration, @@ -166,7 +159,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. + * Return a new DStream by applying `reduceByKey` over a sliding window on `this` DStream. * Similar to `DStream.reduceByKey()`, but applies it over a sliding window. The new DStream * generates RDDs with the same interval as this DStream. Hash partitioning is used to generate * the RDDs with Spark's default number of partitions. @@ -182,7 +175,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. * @param reduceFunc associative reduce function @@ -201,7 +194,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. * @param reduceFunc associative reduce function @@ -210,10 +203,10 @@ extends Serializable { * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. + * @param numPartitions number of partitions of each RDD in the new DStream. */ def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, + reduceFunc: (V, V) => V, windowDuration: Duration, slideDuration: Duration, numPartitions: Int @@ -222,7 +215,7 @@ extends Serializable { } /** - * Create a new DStream by applying `reduceByKey` over a sliding window. Similar to + * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to * `DStream.reduceByKey()`, but applies it over a sliding window. * @param reduceFunc associative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's @@ -230,7 +223,8 @@ extends Serializable { * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval - * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @param partitioner partitioner for controlling the partitioning of each RDD + * in the new DStream. */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, @@ -245,118 +239,78 @@ extends Serializable { } /** - * Create a new DStream by reducing over a using incremental computation. - * The reduced value of over a new window is calculated using the old window's reduce value : + * Return a new DStream by applying incremental `reduceByKey` over a sliding window. + * The reduced value of over a new window is calculated using the old window's reduced value : * 1. reduce the new values that entered the window (e.g., adding new counts) + * * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * + * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. * @param reduceFunc associative reduce function - * @param invReduceFunc inverse function + * @param invReduceFunc inverse reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval + * @param filterFunc Optional function to filter expired key-value pairs; + * only pairs that satisfy the function are retained */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, windowDuration: Duration, - slideDuration: Duration + slideDuration: Duration = self.slideDuration, + numPartitions: Int = ssc.sc.defaultParallelism, + filterFunc: ((K, V)) => Boolean = null ): DStream[(K, V)] = { reduceByKeyAndWindow( - reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner()) - } - - /** - * Create a new DStream by reducing over a using incremental computation. - * The reduced value of over a new window is calculated using the old window's reduce value : - * 1. reduce the new values that entered the window (e.g., adding new counts) - * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. - * However, it is applicable to only "invertible reduce functions". - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. - * @param reduceFunc associative reduce function - * @param invReduceFunc inverse function - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. - */ - def reduceByKeyAndWindow( - reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, - windowDuration: Duration, - slideDuration: Duration, - numPartitions: Int - ): DStream[(K, V)] = { - - reduceByKeyAndWindow( - reduceFunc, invReduceFunc, windowDuration, slideDuration, defaultPartitioner(numPartitions)) + reduceFunc, invReduceFunc, windowDuration, + slideDuration, defaultPartitioner(numPartitions), filterFunc + ) } /** - * Create a new DStream by reducing over a using incremental computation. - * The reduced value of over a new window is calculated using the old window's reduce value : + * Return a new DStream by applying incremental `reduceByKey` over a sliding window. + * The reduced value of over a new window is calculated using the old window's reduced value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) - * This is more efficient that reduceByKeyAndWindow without "inverse reduce" function. + * This is more efficient than reduceByKeyAndWindow without "inverse reduce" function. * However, it is applicable to only "invertible reduce functions". - * @param reduceFunc associative reduce function - * @param invReduceFunc inverse function + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse reduce function * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval - * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @param partitioner partitioner for controlling the partitioning of each RDD in the new DStream. + * @param filterFunc Optional function to filter expired key-value pairs; + * only pairs that satisfy the function are retained */ def reduceByKeyAndWindow( reduceFunc: (V, V) => V, invReduceFunc: (V, V) => V, windowDuration: Duration, slideDuration: Duration, - partitioner: Partitioner + partitioner: Partitioner, + filterFunc: ((K, V)) => Boolean ): DStream[(K, V)] = { val cleanedReduceFunc = ssc.sc.clean(reduceFunc) val cleanedInvReduceFunc = ssc.sc.clean(invReduceFunc) + val cleanedFilterFunc = if (filterFunc != null) Some(ssc.sc.clean(filterFunc)) else None new ReducedWindowedDStream[K, V]( - self, cleanedReduceFunc, cleanedInvReduceFunc, windowDuration, slideDuration, partitioner) - } - - /** - * Create a new DStream by counting the number of values for each key over a window. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. - */ - def countByKeyAndWindow( - windowDuration: Duration, - slideDuration: Duration, - numPartitions: Int = self.ssc.sc.defaultParallelism - ): DStream[(K, Long)] = { - - self.map(x => (x._1, 1L)).reduceByKeyAndWindow( - (x: Long, y: Long) => x + y, - (x: Long, y: Long) => x - y, - windowDuration, - slideDuration, - numPartitions + self, cleanedReduceFunc, cleanedInvReduceFunc, cleanedFilterFunc, + windowDuration, slideDuration, partitioner ) } /** - * Create a new "state" DStream where the state for each key is updated by applying + * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. * Hash partitioning is used to generate the RDDs with Spark's default number of partitions. * @param updateFunc State update function. If `this` function returns None, then @@ -370,7 +324,7 @@ extends Serializable { } /** - * Create a new "state" DStream where the state for each key is updated by applying + * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param updateFunc State update function. If `this` function returns None, then @@ -405,7 +359,7 @@ extends Serializable { } /** - * Create a new "state" DStream where the state for each key is updated by applying + * Return a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. * [[spark.Paxrtitioner]] is used to control the partitioning of each RDD. * @param updateFunc State update function. If `this` function returns None, then @@ -447,7 +401,7 @@ extends Serializable { } /** - * Cogroup `this` DStream with `other` DStream. For each key k in corresponding RDDs of `this` + * Cogroup `this` DStream with `other` DStream using a partitioner. For each key k in corresponding RDDs of `this` * or `other` DStreams, the generated RDD will contains a tuple with the list of values for that * key in both RDDs. Partitioner is used to partition each generated RDD. */ @@ -457,7 +411,7 @@ extends Serializable { ): DStream[(K, (Seq[V], Seq[W]))] = { val cgd = new CoGroupedDStream[K]( - Seq(self.asInstanceOf[DStream[(_, _)]], other.asInstanceOf[DStream[(_, _)]]), + Seq(self.asInstanceOf[DStream[(K, _)]], other.asInstanceOf[DStream[(K, _)]]), partitioner ) val pdfs = new PairDStreamFunctions[K, Seq[Seq[_]]](cgd)( diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index c04ed37de8b21708bdc7ef295441129386f4965b..1c4b22a8981c8db8ecb2a28fcfae2e301d000627 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -9,11 +9,8 @@ class Scheduler(ssc: StreamingContext) extends Logging { initLogging() - val graph = ssc.graph - val concurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) - val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { new CheckpointWriter(ssc.checkpointDir) } else { @@ -23,54 +20,93 @@ class Scheduler(ssc: StreamingContext) extends Logging { val clockClass = System.getProperty("spark.streaming.clock", "spark.streaming.util.SystemClock") val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => generateRDDs(new Time(longTime))) + longTime => generateJobs(new Time(longTime))) + val graph = ssc.graph + var latestTime: Time = null - def start() { - // If context was started from checkpoint, then restart timer such that - // this timer's triggers occur at the same time as the original timer. - // Otherwise just start the timer from scratch, and initialize graph based - // on this first trigger time of the timer. + def start() = synchronized { if (ssc.isCheckpointPresent) { - // If manual clock is being used for testing, then - // either set the manual clock to the last checkpointed time, - // or if the property is defined set it to that time - if (clock.isInstanceOf[ManualClock]) { - val lastTime = ssc.getInitialCheckpoint.checkpointTime.milliseconds - val jumpTime = System.getProperty("spark.streaming.manualClock.jump", "0").toLong - clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime) - } - timer.restart(graph.zeroTime.milliseconds) - logInfo("Scheduler's timer restarted") + restart() } else { - val firstTime = new Time(timer.start()) - graph.start(firstTime - ssc.graph.batchDuration) - logInfo("Scheduler's timer started") + startFirstTime() } logInfo("Scheduler started") } - def stop() { + def stop() = synchronized { timer.stop() - graph.stop() + jobManager.stop() + if (checkpointWriter != null) checkpointWriter.stop() + ssc.graph.stop() logInfo("Scheduler stopped") } - - private def generateRDDs(time: Time) { + + private def startFirstTime() { + val startTime = new Time(timer.getStartTime()) + graph.start(startTime - graph.batchDuration) + timer.start(startTime.milliseconds) + logInfo("Scheduler's timer started at " + startTime) + } + + private def restart() { + + // If manual clock is being used for testing, then + // either set the manual clock to the last checkpointed time, + // or if the property is defined set it to that time + if (clock.isInstanceOf[ManualClock]) { + val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds + val jumpTime = System.getProperty("spark.streaming.manualClock.jump", "0").toLong + clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime) + } + + val batchDuration = ssc.graph.batchDuration + + // Batches when the master was down, that is, + // between the checkpoint and current restart time + val checkpointTime = ssc.initialCheckpoint.checkpointTime + val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds)) + val downTimes = checkpointTime.until(restartTime, batchDuration) + logInfo("Batches during down time: " + downTimes.mkString(", ")) + + // Batches that were unprocessed before failure + val pendingTimes = ssc.initialCheckpoint.pendingTimes + logInfo("Batches pending processing: " + pendingTimes.mkString(", ")) + // Reschedule jobs for these times + val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) + logInfo("Batches to reschedule: " + timesToReschedule.mkString(", ")) + timesToReschedule.foreach(time => + graph.generateJobs(time).foreach(jobManager.runJob) + ) + + // Restart the timer + timer.start(restartTime.milliseconds) + logInfo("Scheduler's timer restarted at " + restartTime) + } + + /** Generate jobs and perform checkpoint for the given `time`. */ + def generateJobs(time: Time) { SparkEnv.set(ssc.env) logInfo("\n-----------------------------------------------------\n") - graph.generateRDDs(time).foreach(jobManager.runJob) - graph.forgetOldRDDs(time) + graph.generateJobs(time).foreach(jobManager.runJob) + latestTime = time + doCheckpoint(time) + } + + /** + * Clear old metadata assuming jobs of `time` have finished processing. + * And also perform checkpoint. + */ + def clearOldMetadata(time: Time) { + ssc.graph.clearOldMetadata(time) doCheckpoint(time) - logInfo("Generated RDDs for time " + time) } - private def doCheckpoint(time: Time) { + /** Perform checkpoint for the give `time`. */ + def doCheckpoint(time: Time) = synchronized { if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { - val startTime = System.currentTimeMillis() + logInfo("Checkpointing graph for time " + time) ssc.graph.updateCheckpointData(time) checkpointWriter.write(new Checkpoint(ssc, time)) - val stopTime = System.currentTimeMillis() - logInfo("Checkpointing the graph took " + (stopTime - startTime) + " ms") } } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 37ba524b4897ea43088d86fe0fbd64d72ad9023a..25c67b279b7d2feb38a5b5d520f3ed0ab73d5f5c 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -1,10 +1,19 @@ package spark.streaming +import akka.actor.Props +import akka.actor.SupervisorStrategy +import akka.zeromq.Subscribe + import spark.streaming.dstream._ import spark.{RDD, Logging, SparkEnv, SparkContext} +import spark.streaming.receivers.ActorReceiver +import spark.streaming.receivers.ReceiverSupervisorStrategy +import spark.streaming.receivers.ZeroMQReceiver import spark.storage.StorageLevel import spark.util.MetadataCleaner +import spark.streaming.receivers.ActorReceiver + import scala.collection.mutable.Queue @@ -17,6 +26,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path import java.util.UUID +import twitter4j.Status /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -30,23 +40,23 @@ class StreamingContext private ( ) extends Logging { /** - * Creates a StreamingContext using an existing SparkContext. + * Create a StreamingContext using an existing SparkContext. * @param sparkContext Existing SparkContext * @param batchDuration The time interval at which streaming data will be divided into batches */ def this(sparkContext: SparkContext, batchDuration: Duration) = this(sparkContext, null, batchDuration) /** - * Creates a StreamingContext by providing the details necessary for creating a new SparkContext. + * Create a StreamingContext by providing the details necessary for creating a new SparkContext. * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - * @param frameworkName A name for your job, to display on the cluster web UI + * @param appName A name for your job, to display on the cluster web UI * @param batchDuration The time interval at which streaming data will be divided into batches */ - def this(master: String, frameworkName: String, batchDuration: Duration) = - this(StreamingContext.createNewSparkContext(master, frameworkName), null, batchDuration) + def this(master: String, appName: String, batchDuration: Duration) = + this(StreamingContext.createNewSparkContext(master, appName), null, batchDuration) /** - * Re-creates a StreamingContext from a checkpoint file. + * Re-create a StreamingContext from a checkpoint file. * @param path Path either to the directory that was specified as the checkpoint directory, or * to the checkpoint file 'graph' or 'graph.bk'. */ @@ -61,7 +71,7 @@ class StreamingContext private ( protected[streaming] val isCheckpointPresent = (cp_ != null) - val sc: SparkContext = { + protected[streaming] val sc: SparkContext = { if (isCheckpointPresent) { new SparkContext(cp_.master, cp_.framework, cp_.sparkHome, cp_.jars) } else { @@ -101,7 +111,12 @@ class StreamingContext private ( protected[streaming] var scheduler: Scheduler = null /** - * Sets each DStreams in this context to remember RDDs it generated in the last given duration. + * Return the associated Spark context + */ + def sparkContext = sc + + /** + * Set each DStreams in this context to remember RDDs it generated in the last given duration. * DStreams remember RDDs only for a limited duration of time and releases them for garbage * collection. This method allows the developer to specify how to long to remember the RDDs ( * if the developer wishes to query old data outside the DStream computation). @@ -112,71 +127,119 @@ class StreamingContext private ( } /** - * Sets the context to periodically checkpoint the DStream operations for master - * fault-tolerance. By default, the graph will be checkpointed every batch interval. + * Set the context to periodically checkpoint the DStream operations for master + * fault-tolerance. The graph will be checkpointed every batch interval. * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored - * @param interval checkpoint interval */ - def checkpoint(directory: String, interval: Duration = null) { + def checkpoint(directory: String) { if (directory != null) { sc.setCheckpointDir(StreamingContext.getSparkCheckpointDir(directory)) checkpointDir = directory - checkpointDuration = interval } else { checkpointDir = null - checkpointDuration = null } } - protected[streaming] def getInitialCheckpoint(): Checkpoint = { + protected[streaming] def initialCheckpoint: Checkpoint = { if (isCheckpointPresent) cp_ else null } protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() + /** + * Create an input stream with any arbitrary user implemented network receiver. + * @param receiver Custom implementation of NetworkReceiver + */ + def networkStream[T: ClassManifest]( + receiver: NetworkReceiver[T]): DStream[T] = { + val inputStream = new PluggableInputDStream[T](this, + receiver) + graph.addInputStream(inputStream) + inputStream + } + + /** + * Create an input stream with any arbitrary user implemented actor receiver. + * @param props Props object defining creation of the actor + * @param name Name of the actor + * @param storageLevel RDD storage level. Defaults to memory-only. + * + * @note An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e parametrized type of data received and actorStream + * should be same. + */ + def actorStream[T: ClassManifest]( + props: Props, + name: String, + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2, + supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy): DStream[T] = { + networkStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy)) + } + + /** + * Create an input stream that receives messages pushed by a zeromq publisher. + * @param publisherUrl Url of remote zeromq publisher + * @param subscribe topic to subscribe to + * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence + * of byte thus it needs the converter(which might be deserializer of bytes) + * to translate from sequence of sequence of bytes, where sequence refer to a frame + * and sub sequence refer to its payload. + * @param storageLevel RDD storage level. Defaults to memory-only. + */ + def zeroMQStream[T: ClassManifest]( + publisherUrl:String, + subscribe: Subscribe, + bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T], + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2, + supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy + ): DStream[T] = { + actorStream(Props(new ZeroMQReceiver(publisherUrl,subscribe,bytesToObjects)), + "ZeroMQReceiver", storageLevel, supervisorStrategy) + } + /** * Create an input stream that pulls messages form a Kafka Broker. - * @param hostname Zookeper hostname. - * @param port Zookeper port. + * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param initialOffsets Optional initial offsets for each of the partitions to consume. * By default the value is pulled from zookeper. - * @param storageLevel RDD storage level. Defaults to memory-only. + * @param storageLevel Storage level to use for storing the received objects + * (default: StorageLevel.MEMORY_AND_DISK_SER_2) */ def kafkaStream[T: ClassManifest]( - hostname: String, - port: Int, + zkQuorum: String, groupId: String, topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](), storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 ): DStream[T] = { - val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, topics, initialOffsets, storageLevel) + val inputStream = new KafkaInputDStream[T](this, zkQuorum, groupId, topics, initialOffsets, storageLevel) registerInputStream(inputStream) inputStream } /** - * Create a input stream from network source hostname:port. Data is received using - * a TCP socket and the receive bytes is interpreted as UTF8 encoded \n delimited + * Create a input stream from TCP source hostname:port. Data is received using + * a TCP socket and the receive bytes is interpreted as UTF8 encoded `\n` delimited * lines. * @param hostname Hostname to connect to for receiving data * @param port Port to connect to for receiving data * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) */ - def networkTextStream( + def socketTextStream( hostname: String, port: Int, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): DStream[String] = { - networkStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel) + socketStream[String](hostname, port, SocketReceiver.bytesToLines, storageLevel) } /** - * Create a input stream from network source hostname:port. Data is received using + * Create a input stream from TCP source hostname:port. Data is received using * a TCP socket and the receive bytes it interepreted as object using the given * converter. * @param hostname Hostname to connect to for receiving data @@ -185,7 +248,7 @@ class StreamingContext private ( * @param storageLevel Storage level to use for storing the received objects * @tparam T Type of the objects received (after converting bytes to objects) */ - def networkStream[T: ClassManifest]( + def socketStream[T: ClassManifest]( hostname: String, port: Int, converter: (InputStream) => Iterator[T], @@ -197,7 +260,7 @@ class StreamingContext private ( } /** - * Creates a input stream from a Flume source. + * Create a input stream from a Flume source. * @param hostname Hostname of the slave machine to which the flume data will be sent * @param port Port of the slave machine to which the flume data will be sent * @param storageLevel Storage level to use for storing the received objects @@ -222,7 +285,7 @@ class StreamingContext private ( * @param storageLevel Storage level to use for storing the received objects * @tparam T Type of the objects in the received blocks */ - def rawNetworkStream[T: ClassManifest]( + def rawSocketStream[T: ClassManifest]( hostname: String, port: Int, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 @@ -233,7 +296,7 @@ class StreamingContext private ( } /** - * Creates a input stream that monitors a Hadoop-compatible filesystem + * Create a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them using the given key-value types and input format. * File names starting with . are ignored. * @param directory HDFS directory to monitor for new file @@ -252,7 +315,7 @@ class StreamingContext private ( } /** - * Creates a input stream that monitors a Hadoop-compatible filesystem + * Create a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them using the given key-value types and input format. * @param directory HDFS directory to monitor for new file * @param filter Function to filter paths to process @@ -271,9 +334,8 @@ class StreamingContext private ( inputStream } - /** - * Creates a input stream that monitors a Hadoop-compatible filesystem + * Create a input stream that monitors a Hadoop-compatible filesystem * for new files and reads them as text files (using key as LongWritable, value * as Text and input format as TextInputFormat). File names starting with . are ignored. * @param directory HDFS directory to monitor for new file @@ -283,17 +345,49 @@ class StreamingContext private ( } /** - * Creates a input stream from an queue of RDDs. In each batch, + * Create a input stream that returns tweets received from Twitter. + * @param username Twitter username + * @param password Twitter password + * @param filters Set of filter strings to get only those tweets that match them + * @param storageLevel Storage level to use for storing the received objects + */ + def twitterStream( + username: String, + password: String, + filters: Seq[String] = Nil, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 + ): DStream[Status] = { + val inputStream = new TwitterInputDStream(this, username, password, filters, storageLevel) + registerInputStream(inputStream) + inputStream + } + + /** + * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval - * @param defaultRDD Default RDD is returned by the DStream when the queue is empty * @tparam T Type of objects in the RDD */ def queueStream[T: ClassManifest]( queue: Queue[RDD[T]], - oneAtATime: Boolean = true, - defaultRDD: RDD[T] = null + oneAtATime: Boolean = true + ): DStream[T] = { + queueStream(queue, oneAtATime, sc.makeRDD(Seq[T](), 1)) + } + + /** + * Create an input stream from a queue of RDDs. In each batch, + * it will process either one or all of the RDDs returned by the queue. + * @param queue Queue of RDDs + * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval + * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. Set as null if no RDD should be returned when empty + * @tparam T Type of objects in the RDD + */ + def queueStream[T: ClassManifest]( + queue: Queue[RDD[T]], + oneAtATime: Boolean, + defaultRDD: RDD[T] ): DStream[T] = { val inputStream = new QueueInputDStream(this, queue, oneAtATime, defaultRDD) registerInputStream(inputStream) @@ -308,7 +402,7 @@ class StreamingContext private ( } /** - * Registers an input stream that will be started (InputDStream.start() called) to get the + * Register an input stream that will be started (InputDStream.start() called) to get the * input data. */ def registerInputStream(inputStream: InputDStream[_]) { @@ -316,7 +410,7 @@ class StreamingContext private ( } /** - * Registers an output stream that will be computed every interval + * Register an output stream that will be computed every interval */ def registerOutputStream(outputStream: DStream[_]) { graph.addOutputStream(outputStream) @@ -334,7 +428,7 @@ class StreamingContext private ( } /** - * Starts the execution of the streams. + * Start the execution of the streams. */ def start() { if (checkpointDir != null && checkpointDuration == null && graph != null) { @@ -362,7 +456,7 @@ class StreamingContext private ( } /** - * Sstops the execution of the streams. + * Stop the execution of the streams. */ def stop() { try { @@ -384,14 +478,14 @@ object StreamingContext { new PairDStreamFunctions[K, V](stream) } - protected[streaming] def createNewSparkContext(master: String, frameworkName: String): SparkContext = { + protected[streaming] def createNewSparkContext(master: String, appName: String): SparkContext = { // Set the default cleaner delay to an hour if not already set. // This should be sufficient for even 1 second interval. if (MetadataCleaner.getDelaySeconds < 0) { MetadataCleaner.setDelaySeconds(3600) } - new SparkContext(master, frameworkName) + new SparkContext(master, appName) } protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { @@ -408,4 +502,3 @@ object StreamingContext { new Path(sscCheckpointDir, UUID.randomUUID.toString).toString } } - diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 5daeb761ddd670b6daa50ad5226ffb07496e5218..f14decf08ba8f98bc4f59e062b0b8ea17d7ade1e 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -37,6 +37,19 @@ case class Time(private val millis: Long) { def max(that: Time): Time = if (this > that) this else that + def until(that: Time, interval: Duration): Seq[Time] = { + (this.milliseconds) until (that.milliseconds) by (interval.milliseconds) map (new Time(_)) + } + + def to(that: Time, interval: Duration): Seq[Time] = { + (this.milliseconds) to (that.milliseconds) by (interval.milliseconds) map (new Time(_)) + } + + override def toString: String = (millis.toString + " ms") +} + +object Time { + val ordering = Ordering.by((time: Time) => time.millis) } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala index 2e7466b16c93edb162ac1027a0a0e75dfff7b850..4d93f0a5f729e48cbda0c18f6104449ccb48abc7 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStream.scala @@ -4,6 +4,7 @@ import spark.streaming.{Duration, Time, DStream} import spark.api.java.function.{Function => JFunction} import spark.api.java.JavaRDD import spark.storage.StorageLevel +import spark.RDD /** * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous @@ -16,9 +17,7 @@ import spark.storage.StorageLevel * * This class contains the basic operations available on all DStreams, such as `map`, `filter` and * `window`. In addition, [[spark.streaming.api.java.JavaPairDStream]] contains operations available - * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`. These operations - * are automatically available on any DStream of the right type (e.g., DStream[(Int, Int)] through - * implicit conversions when `spark.streaming.StreamingContext._` is imported. + * only on DStreams of key-value pairs, such as `groupByKeyAndWindow` and `join`. * * DStreams internally is characterized by a few basic properties: * - A list of other DStreams that the DStream depends on @@ -26,7 +25,9 @@ import spark.storage.StorageLevel * - A function that is used to generate an RDD after each time interval */ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassManifest[T]) - extends JavaDStreamLike[T, JavaDStream[T]] { + extends JavaDStreamLike[T, JavaDStream[T], JavaRDD[T]] { + + override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) /** Return a new DStream containing only the elements that satisfy a predicate. */ def filter(f: JFunction[T, java.lang.Boolean]): JavaDStream[T] = @@ -36,7 +37,7 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM def cache(): JavaDStream[T] = dstream.cache() /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ - def persist(): JavaDStream[T] = dstream.cache() + def persist(): JavaDStream[T] = dstream.persist() /** Persist the RDDs of this DStream with the given storage level */ def persist(storageLevel: StorageLevel): JavaDStream[T] = dstream.persist(storageLevel) @@ -50,33 +51,26 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM } /** - * Return a new DStream which is computed based on windowed batches of this DStream. - * The new DStream generates RDDs with the same interval as this DStream. + * Return a new DStream in which each RDD contains all the elements in seen in a + * sliding window of time over this DStream. The new DStream generates RDDs with + * the same interval as this DStream. * @param windowDuration width of the window; must be a multiple of this DStream's interval. - * @return */ def window(windowDuration: Duration): JavaDStream[T] = dstream.window(windowDuration) /** - * Return a new DStream which is computed based on windowed batches of this DStream. - * @param windowDuration duration (i.e., width) of the window; - * must be a multiple of this DStream's interval + * Return a new DStream in which each RDD contains all the elements in seen in a + * sliding window of time over this DStream. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's interval + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval */ def window(windowDuration: Duration, slideDuration: Duration): JavaDStream[T] = dstream.window(windowDuration, slideDuration) - /** - * Return a new DStream which computed based on tumbling window on this DStream. - * This is equivalent to window(batchDuration, batchDuration). - * @param batchDuration tumbling window duration; must be a multiple of this DStream's interval - */ - def tumble(batchDuration: Duration): JavaDStream[T] = - dstream.tumble(batchDuration) - /** * Return a new DStream by unifying data of another DStream with this DStream. * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream. diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala index b93cb7865aef03af96e1bb9547c99786e4980198..548809a359644d21785bc4f8c96951a42f1c07f3 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaDStreamLike.scala @@ -6,17 +6,20 @@ import java.lang.{Long => JLong} import scala.collection.JavaConversions._ import spark.streaming._ -import spark.api.java.JavaRDD +import spark.api.java.{JavaPairRDD, JavaRDDLike, JavaRDD} import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _} import java.util import spark.RDD import JavaDStream._ -trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable { +trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]] + extends Serializable { implicit val classManifest: ClassManifest[T] def dstream: DStream[T] + def wrapRDD(in: RDD[T]): R + implicit def scalaIntToJavaLong(in: DStream[Long]): JavaDStream[JLong] = { in.map(new JLong(_)) } @@ -33,6 +36,26 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable */ def count(): JavaDStream[JLong] = dstream.count() + /** + * Return a new DStream in which each RDD contains the counts of each distinct value in + * each RDD of this DStream. Hash partitioning is used to generate the RDDs with + * Spark's default number of partitions. + */ + def countByValue(): JavaPairDStream[T, JLong] = { + JavaPairDStream.scalaToJavaLong(dstream.countByValue()) + } + + /** + * Return a new DStream in which each RDD contains the counts of each distinct value in + * each RDD of this DStream. Hash partitioning is used to generate the RDDs with `numPartitions` + * partitions. + * @param numPartitions number of partitions of each RDD in the new DStream. + */ + def countByValue(numPartitions: Int): JavaPairDStream[T, JLong] = { + JavaPairDStream.scalaToJavaLong(dstream.countByValue(numPartitions)) + } + + /** * Return a new DStream in which each RDD has a single element generated by counting the number * of elements in a window over this DStream. windowDuration and slideDuration are as defined in the @@ -42,6 +65,39 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable dstream.countByWindow(windowDuration, slideDuration) } + /** + * Return a new DStream in which each RDD contains the count of distinct elements in + * RDDs in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with + * Spark's default number of partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration) + : JavaPairDStream[T, JLong] = { + JavaPairDStream.scalaToJavaLong( + dstream.countByValueAndWindow(windowDuration, slideDuration)) + } + + /** + * Return a new DStream in which each RDD contains the count of distinct elements in + * RDDs in a sliding window over this DStream. Hash partitioning is used to generate the RDDs with `numPartitions` + * partitions. + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + * @param numPartitions number of partitions of each RDD in the new DStream. + */ + def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) + : JavaPairDStream[T, JLong] = { + JavaPairDStream.scalaToJavaLong( + dstream.countByValueAndWindow(windowDuration, slideDuration, numPartitions)) + } + /** * Return a new DStream in which each RDD is generated by applying glom() to each RDD of * this DStream. Applying glom() to an RDD coalesces all elements within each partition into @@ -59,8 +115,8 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable } /** Return a new DStream by applying a function to all elements of this DStream. */ - def map[K, V](f: PairFunction[T, K, V]): JavaPairDStream[K, V] = { - def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] + def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { + def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] new JavaPairDStream(dstream.map(f)(cm))(f.keyType(), f.valueType()) } @@ -78,10 +134,10 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable * Return a new DStream by applying a function to all elements of this DStream, * and then flattening the results */ - def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairDStream[K, V] = { + def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairDStream[K2, V2] = { import scala.collection.JavaConverters._ def fn = (x: T) => f.apply(x).asScala - def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] + def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] new JavaPairDStream(dstream.flatMap(fn)(cm))(f.keyType(), f.valueType()) } @@ -100,8 +156,8 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition * of the RDD. */ - def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V]) - : JavaPairDStream[K, V] = { + def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]) + : JavaPairDStream[K2, V2] = { def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator()) new JavaPairDStream(dstream.mapPartitions(fn))(f.keyType(), f.valueType()) } @@ -114,8 +170,38 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable /** * Return a new DStream in which each RDD has a single element generated by reducing all - * elements in a window over this DStream. windowDuration and slideDuration are as defined in the - * window() operation. This is equivalent to window(windowDuration, slideDuration).reduce(reduceFunc) + * elements in a sliding window over this DStream. + * @param reduceFunc associative reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval + */ + def reduceByWindow( + reduceFunc: (T, T) => T, + windowDuration: Duration, + slideDuration: Duration + ): DStream[T] = { + dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration) + } + + + /** + * Return a new DStream in which each RDD has a single element generated by reducing all + * elements in a sliding window over this DStream. However, the reduction is done incrementally + * using the old window's reduced value : + * 1. reduce the new values that entered the window (e.g., adding new counts) + * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + * This is more efficient than reduceByWindow without "inverse reduce" function. + * However, it is applicable to only "invertible reduce functions". + * @param reduceFunc associative reduce function + * @param invReduceFunc inverse reduce function + * @param windowDuration width of the window; must be a multiple of this DStream's + * batching interval + * @param slideDuration sliding interval of the window (i.e., the interval after which + * the new DStream will generate RDDs); must be a multiple of this + * DStream's batching interval */ def reduceByWindow( reduceFunc: JFunction2[T, T, T], @@ -129,35 +215,35 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable /** * Return all the RDDs between 'fromDuration' to 'toDuration' (both included) */ - def slice(fromDuration: Duration, toDuration: Duration): JList[JavaRDD[T]] = { - new util.ArrayList(dstream.slice(fromDuration, toDuration).map(new JavaRDD(_)).toSeq) + def slice(fromTime: Time, toTime: Time): JList[R] = { + new util.ArrayList(dstream.slice(fromTime, toTime).map(wrapRDD(_)).toSeq) } /** * Apply a function to each RDD in this DStream. This is an output operator, so * this DStream will be registered as an output stream and therefore materialized. */ - def foreach(foreachFunc: JFunction[JavaRDD[T], Void]) { - dstream.foreach(rdd => foreachFunc.call(new JavaRDD(rdd))) + def foreach(foreachFunc: JFunction[R, Void]) { + dstream.foreach(rdd => foreachFunc.call(wrapRDD(rdd))) } /** * Apply a function to each RDD in this DStream. This is an output operator, so * this DStream will be registered as an output stream and therefore materialized. */ - def foreach(foreachFunc: JFunction2[JavaRDD[T], Time, Void]) { - dstream.foreach((rdd, time) => foreachFunc.call(new JavaRDD(rdd), time)) + def foreach(foreachFunc: JFunction2[R, Time, Void]) { + dstream.foreach((rdd, time) => foreachFunc.call(wrapRDD(rdd), time)) } /** * Return a new DStream in which each RDD is generated by applying a function * on each RDD of this DStream. */ - def transform[U](transformFunc: JFunction[JavaRDD[T], JavaRDD[U]]): JavaDStream[U] = { + def transform[U](transformFunc: JFunction[R, JavaRDD[U]]): JavaDStream[U] = { implicit val cm: ClassManifest[U] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] def scalaTransform (in: RDD[T]): RDD[U] = - transformFunc.call(new JavaRDD[T](in)).rdd + transformFunc.call(wrapRDD(in)).rdd dstream.transform(scalaTransform(_)) } @@ -165,11 +251,41 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This]] extends Serializable * Return a new DStream in which each RDD is generated by applying a function * on each RDD of this DStream. */ - def transform[U](transformFunc: JFunction2[JavaRDD[T], Time, JavaRDD[U]]): JavaDStream[U] = { + def transform[U](transformFunc: JFunction2[R, Time, JavaRDD[U]]): JavaDStream[U] = { implicit val cm: ClassManifest[U] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] def scalaTransform (in: RDD[T], time: Time): RDD[U] = - transformFunc.call(new JavaRDD[T](in), time).rdd + transformFunc.call(wrapRDD(in), time).rdd + dstream.transform(scalaTransform(_, _)) + } + + /** + * Return a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ + def transform[K2, V2](transformFunc: JFunction[R, JavaPairRDD[K2, V2]]): + JavaPairDStream[K2, V2] = { + implicit val cmk: ClassManifest[K2] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K2]] + implicit val cmv: ClassManifest[V2] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V2]] + def scalaTransform (in: RDD[T]): RDD[(K2, V2)] = + transformFunc.call(wrapRDD(in)).rdd + dstream.transform(scalaTransform(_)) + } + + /** + * Return a new DStream in which each RDD is generated by applying a function + * on each RDD of this DStream. + */ + def transform[K2, V2](transformFunc: JFunction2[R, Time, JavaPairRDD[K2, V2]]): + JavaPairDStream[K2, V2] = { + implicit val cmk: ClassManifest[K2] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K2]] + implicit val cmv: ClassManifest[V2] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V2]] + def scalaTransform (in: RDD[T], time: Time): RDD[(K2, V2)] = + transformFunc.call(wrapRDD(in), time).rdd dstream.transform(scalaTransform(_, _)) } diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index ef10c091caaae66b352c0a8d24b86fe14240e215..30240cad988be32e82c701a0dd7b535156c3dcb4 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -8,34 +8,37 @@ import scala.collection.JavaConversions._ import spark.streaming._ import spark.streaming.StreamingContext._ import spark.api.java.function.{Function => JFunction, Function2 => JFunction2} -import spark.Partitioner +import spark.{RDD, Partitioner} import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.hadoop.conf.Configuration -import spark.api.java.JavaPairRDD +import spark.api.java.{JavaRDD, JavaPairRDD} import spark.storage.StorageLevel import com.google.common.base.Optional +import spark.RDD class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( implicit val kManifiest: ClassManifest[K], implicit val vManifest: ClassManifest[V]) - extends JavaDStreamLike[(K, V), JavaPairDStream[K, V]] { + extends JavaDStreamLike[(K, V), JavaPairDStream[K, V], JavaPairRDD[K, V]] { + + override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) // ======================================================================= // Methods common to all DStream's // ======================================================================= - /** Returns a new DStream containing only the elements that satisfy a predicate. */ + /** Return a new DStream containing only the elements that satisfy a predicate. */ def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairDStream[K, V] = dstream.filter((x => f(x).booleanValue())) - /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ def cache(): JavaPairDStream[K, V] = dstream.cache() - /** Persists RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ - def persist(): JavaPairDStream[K, V] = dstream.cache() + /** Persist RDDs of this DStream with the default storage level (MEMORY_ONLY_SER) */ + def persist(): JavaPairDStream[K, V] = dstream.persist() - /** Persists the RDDs of this DStream with the given storage level */ + /** Persist the RDDs of this DStream with the given storage level */ def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = dstream.persist(storageLevel) /** Method that generates a RDD for the given Duration */ @@ -67,15 +70,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.window(windowDuration, slideDuration) /** - * Returns a new DStream which computed based on tumbling window on this DStream. - * This is equivalent to window(batchDuration, batchDuration). - * @param batchDuration tumbling window duration; must be a multiple of this DStream's interval - */ - def tumble(batchDuration: Duration): JavaPairDStream[K, V] = - dstream.tumble(batchDuration) - - /** - * Returns a new DStream by unifying data of another DStream with this DStream. + * Return a new DStream by unifying data of another DStream with this DStream. * @param that Another DStream having the same interval (i.e., slideDuration) as this DStream. */ def union(that: JavaPairDStream[K, V]): JavaPairDStream[K, V] = @@ -86,21 +81,21 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( // ======================================================================= /** - * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. */ def groupByKey(): JavaPairDStream[K, JList[V]] = dstream.groupByKey().mapValues(seqAsJavaList _) /** - * Create a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to + * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. */ def groupByKey(numPartitions: Int): JavaPairDStream[K, JList[V]] = dstream.groupByKey(numPartitions).mapValues(seqAsJavaList _) /** - * Creates a new DStream by applying `groupByKey` on each RDD of `this` DStream. + * Return a new DStream by applying `groupByKey` on each RDD of `this` DStream. * Therefore, the values for each key in `this` DStream's RDDs are grouped into a * single sequence to generate the RDDs of the new DStream. [[spark.Partitioner]] * is used to control the partitioning of each RDD. @@ -109,7 +104,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.groupByKey(partitioner).mapValues(seqAsJavaList _) /** - * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the associative reduce function. Hash partitioning is used to generate the RDDs * with Spark's default number of partitions. */ @@ -117,7 +112,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.reduceByKey(func) /** - * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the supplied reduce function. Hash partitioning is used to generate the RDDs * with `numPartitions` partitions. */ @@ -125,7 +120,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.reduceByKey(func, numPartitions) /** - * Create a new DStream by applying `reduceByKey` to each RDD. The values for each key are + * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are * merged using the supplied reduce function. [[spark.Partitioner]] is used to control the * partitioning of each RDD. */ @@ -149,24 +144,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by counting the number of values of each key in each RDD. Hash - * partitioning is used to generate the RDDs with Spark's `numPartitions` partitions. - */ - def countByKey(numPartitions: Int): JavaPairDStream[K, JLong] = { - JavaPairDStream.scalaToJavaLong(dstream.countByKey(numPartitions)); - } - - - /** - * Create a new DStream by counting the number of values of each key in each RDD. Hash - * partitioning is used to generate the RDDs with the default number of partitions. - */ - def countByKey(): JavaPairDStream[K, JLong] = { - JavaPairDStream.scalaToJavaLong(dstream.countByKey()); - } - - /** - * Creates a new DStream by applying `groupByKey` over a sliding window. This is similar to + * Return a new DStream by applying `groupByKey` over a sliding window. This is similar to * `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs * with the same interval as this DStream. Hash partitioning is used to generate the RDDs with * Spark's default number of partitions. @@ -178,7 +156,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying `groupByKey` over a sliding window. Similar to + * Return a new DStream by applying `groupByKey` over a sliding window. Similar to * `DStream.groupByKey()`, but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. * @param windowDuration width of the window; must be a multiple of this DStream's @@ -193,7 +171,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream. * Similar to `DStream.groupByKey()`, but applies it over a sliding window. * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. * @param windowDuration width of the window; must be a multiple of this DStream's @@ -210,7 +188,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying `groupByKey` over a sliding window on `this` DStream. + * Return a new DStream by applying `groupByKey` over a sliding window on `this` DStream. * Similar to `DStream.groupByKey()`, but applies it over a sliding window. * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval @@ -243,7 +221,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. * @param reduceFunc associative reduce function @@ -262,7 +240,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying `reduceByKey` over a sliding window. This is similar to + * Return a new DStream by applying `reduceByKey` over a sliding window. This is similar to * `DStream.reduceByKey()` but applies it over a sliding window. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. * @param reduceFunc associative reduce function @@ -283,7 +261,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by applying `reduceByKey` over a sliding window. Similar to + * Return a new DStream by applying `reduceByKey` over a sliding window. Similar to * `DStream.reduceByKey()`, but applies it over a sliding window. * @param reduceFunc associative reduce function * @param windowDuration width of the window; must be a multiple of this DStream's @@ -303,7 +281,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by reducing over a using incremental computation. + * Return a new DStream by reducing over a using incremental computation. * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) @@ -328,7 +306,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( } /** - * Create a new DStream by reducing over a using incremental computation. + * Return a new DStream by applying incremental `reduceByKey` over a sliding window. * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) @@ -342,25 +320,31 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * @param slideDuration sliding interval of the window (i.e., the interval after which * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. + * @param numPartitions number of partitions of each RDD in the new DStream. + * @param filterFunc function to filter expired key-value pairs; + * only pairs that satisfy the function are retained + * set this to null if you do not want to filter */ def reduceByKeyAndWindow( reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], windowDuration: Duration, slideDuration: Duration, - numPartitions: Int + numPartitions: Int, + filterFunc: JFunction[(K, V), java.lang.Boolean] ): JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow( reduceFunc, invReduceFunc, windowDuration, slideDuration, - numPartitions) + numPartitions, + (p: (K, V)) => filterFunc(p).booleanValue() + ) } /** - * Create a new DStream by reducing over a using incremental computation. + * Return a new DStream by applying incremental `reduceByKey` over a sliding window. * The reduced value of over a new window is calculated using the old window's reduce value : * 1. reduce the new values that entered the window (e.g., adding new counts) * 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) @@ -374,49 +358,26 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @param filterFunc function to filter expired key-value pairs; + * only pairs that satisfy the function are retained + * set this to null if you do not want to filter */ def reduceByKeyAndWindow( reduceFunc: Function2[V, V, V], invReduceFunc: Function2[V, V, V], windowDuration: Duration, slideDuration: Duration, - partitioner: Partitioner - ): JavaPairDStream[K, V] = { + partitioner: Partitioner, + filterFunc: JFunction[(K, V), java.lang.Boolean] + ): JavaPairDStream[K, V] = { dstream.reduceByKeyAndWindow( reduceFunc, invReduceFunc, windowDuration, slideDuration, - partitioner) - } - - /** - * Create a new DStream by counting the number of values for each key over a window. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - */ - def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) - : JavaPairDStream[K, JLong] = { - JavaPairDStream.scalaToJavaLong(dstream.countByKeyAndWindow(windowDuration, slideDuration)) - } - - /** - * Create a new DStream by counting the number of values for each key over a window. - * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. - * @param windowDuration width of the window; must be a multiple of this DStream's - * batching interval - * @param slideDuration sliding interval of the window (i.e., the interval after which - * the new DStream will generate RDDs); must be a multiple of this - * DStream's batching interval - * @param numPartitions Number of partitions of each RDD in the new DStream. - */ - def countByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - : JavaPairDStream[K, Long] = { - dstream.countByKeyAndWindow(windowDuration, slideDuration, numPartitions) + partitioner, + (p: (K, V)) => filterFunc(p).booleanValue() + ) } private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala index e7f446a49b581ba249ae84c7436dc7b417a9f1b9..f3b40b5b88ef9dd80f2f717ccdd2bf151ddd6f00 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaStreamingContext.scala @@ -1,16 +1,26 @@ package spark.streaming.api.java -import scala.collection.JavaConversions._ -import java.lang.{Long => JLong, Integer => JInt} - import spark.streaming._ -import dstream._ +import receivers.{ActorReceiver, ReceiverSupervisorStrategy} +import spark.streaming.dstream._ import spark.storage.StorageLevel + import spark.api.java.function.{Function => JFunction, Function2 => JFunction2} +import spark.api.java.{JavaSparkContext, JavaRDD} + import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} + +import twitter4j.Status + +import akka.actor.Props +import akka.actor.SupervisorStrategy +import akka.zeromq.Subscribe + +import scala.collection.JavaConversions._ + +import java.lang.{Long => JLong, Integer => JInt} import java.io.InputStream import java.util.{Map => JMap} -import spark.api.java.{JavaSparkContext, JavaRDD} /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -27,11 +37,11 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Creates a StreamingContext. * @param master Name of the Spark Master - * @param frameworkName Name to be used when registering with the scheduler + * @param appName Name to be used when registering with the scheduler * @param batchDuration The time interval at which streaming data will be divided into batches */ - def this(master: String, frameworkName: String, batchDuration: Duration) = - this(new StreamingContext(master, frameworkName, batchDuration)) + def this(master: String, appName: String, batchDuration: Duration) = + this(new StreamingContext(master, appName, batchDuration)) /** * Creates a StreamingContext. @@ -53,27 +63,24 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create an input stream that pulls messages form a Kafka Broker. - * @param hostname Zookeper hostname. - * @param port Zookeper port. + * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. */ def kafkaStream[T]( - hostname: String, - port: Int, + zkQuorum: String, groupId: String, topics: JMap[String, JInt]) : JavaDStream[T] = { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.kafkaStream[T](hostname, port, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) + ssc.kafkaStream[T](zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*)) } /** * Create an input stream that pulls messages form a Kafka Broker. - * @param hostname Zookeper hostname. - * @param port Zookeper port. + * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. @@ -81,8 +88,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { * By default the value is pulled from zookeper. */ def kafkaStream[T]( - hostname: String, - port: Int, + zkQuorum: String, groupId: String, topics: JMap[String, JInt], initialOffsets: JMap[KafkaPartitionKey, JLong]) @@ -90,8 +96,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] ssc.kafkaStream[T]( - hostname, - port, + zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), Map(initialOffsets.mapValues(_.longValue()).toSeq: _*)) @@ -99,8 +104,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Create an input stream that pulls messages form a Kafka Broker. - * @param hostname Zookeper hostname. - * @param port Zookeper port. + * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. @@ -109,8 +113,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param storageLevel RDD storage level. Defaults to memory-only */ def kafkaStream[T]( - hostname: String, - port: Int, + zkQuorum: String, groupId: String, topics: JMap[String, JInt], initialOffsets: JMap[KafkaPartitionKey, JLong], @@ -119,8 +122,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] ssc.kafkaStream[T]( - hostname, - port, + zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*), Map(initialOffsets.mapValues(_.longValue()).toSeq: _*), @@ -136,9 +138,9 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) */ - def networkTextStream(hostname: String, port: Int, storageLevel: StorageLevel) + def socketTextStream(hostname: String, port: Int, storageLevel: StorageLevel) : JavaDStream[String] = { - ssc.networkTextStream(hostname, port, storageLevel) + ssc.socketTextStream(hostname, port, storageLevel) } /** @@ -148,8 +150,8 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param hostname Hostname to connect to for receiving data * @param port Port to connect to for receiving data */ - def networkTextStream(hostname: String, port: Int): JavaDStream[String] = { - ssc.networkTextStream(hostname, port) + def socketTextStream(hostname: String, port: Int): JavaDStream[String] = { + ssc.socketTextStream(hostname, port) } /** @@ -162,7 +164,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param storageLevel Storage level to use for storing the received objects * @tparam T Type of the objects received (after converting bytes to objects) */ - def networkStream[T]( + def socketStream[T]( hostname: String, port: Int, converter: JFunction[InputStream, java.lang.Iterable[T]], @@ -171,7 +173,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { def fn = (x: InputStream) => converter.apply(x).toIterator implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - ssc.networkStream(hostname, port, fn, storageLevel) + ssc.socketStream(hostname, port, fn, storageLevel) } /** @@ -194,13 +196,13 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param storageLevel Storage level to use for storing the received objects * @tparam T Type of the objects in the received blocks */ - def rawNetworkStream[T]( + def rawSocketStream[T]( hostname: String, port: Int, storageLevel: StorageLevel): JavaDStream[T] = { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - JavaDStream.fromDStream(ssc.rawNetworkStream(hostname, port, storageLevel)) + JavaDStream.fromDStream(ssc.rawSocketStream(hostname, port, storageLevel)) } /** @@ -212,10 +214,10 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param port Port to connect to for receiving data * @tparam T Type of the objects in the received blocks */ - def rawNetworkStream[T](hostname: String, port: Int): JavaDStream[T] = { + def rawSocketStream[T](hostname: String, port: Int): JavaDStream[T] = { implicit val cmt: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - JavaDStream.fromDStream(ssc.rawNetworkStream(hostname, port)) + JavaDStream.fromDStream(ssc.rawSocketStream(hostname, port)) } /** @@ -254,15 +256,182 @@ class JavaStreamingContext(val ssc: StreamingContext) { * @param hostname Hostname of the slave machine to which the flume data will be sent * @param port Port of the slave machine to which the flume data will be sent */ - def flumeStream(hostname: String, port: Int): - JavaDStream[SparkFlumeEvent] = { + def flumeStream(hostname: String, port: Int): JavaDStream[SparkFlumeEvent] = { ssc.flumeStream(hostname, port) } + /** + * Create a input stream that returns tweets received from Twitter. + * @param username Twitter username + * @param password Twitter password + * @param filters Set of filter strings to get only those tweets that match them + * @param storageLevel Storage level to use for storing the received objects + */ + def twitterStream( + username: String, + password: String, + filters: Array[String], + storageLevel: StorageLevel + ): JavaDStream[Status] = { + ssc.twitterStream(username, password, filters, storageLevel) + } + + /** + * Create a input stream that returns tweets received from Twitter. + * @param username Twitter username + * @param password Twitter password + * @param filters Set of filter strings to get only those tweets that match them + */ + def twitterStream( + username: String, + password: String, + filters: Array[String] + ): JavaDStream[Status] = { + ssc.twitterStream(username, password, filters) + } + + /** + * Create a input stream that returns tweets received from Twitter. + * @param username Twitter username + * @param password Twitter password + */ + def twitterStream( + username: String, + password: String + ): JavaDStream[Status] = { + ssc.twitterStream(username, password) + } + + /** + * Create an input stream with any arbitrary user implemented actor receiver. + * @param props Props object defining creation of the actor + * @param name Name of the actor + * @param storageLevel Storage level to use for storing the received objects + * + * @note An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e parametrized type of data received and actorStream + * should be same. + */ + def actorStream[T]( + props: Props, + name: String, + storageLevel: StorageLevel, + supervisorStrategy: SupervisorStrategy + ): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + ssc.actorStream[T](props, name, storageLevel, supervisorStrategy) + } + + /** + * Create an input stream with any arbitrary user implemented actor receiver. + * @param props Props object defining creation of the actor + * @param name Name of the actor + * @param storageLevel Storage level to use for storing the received objects + * + * @note An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e parametrized type of data received and actorStream + * should be same. + */ + def actorStream[T]( + props: Props, + name: String, + storageLevel: StorageLevel + ): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + ssc.actorStream[T](props, name, storageLevel) + } + + /** + * Create an input stream with any arbitrary user implemented actor receiver. + * @param props Props object defining creation of the actor + * @param name Name of the actor + * + * @note An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e parametrized type of data received and actorStream + * should be same. + */ + def actorStream[T]( + props: Props, + name: String + ): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + ssc.actorStream[T](props, name) + } + + /** + * Create an input stream that receives messages pushed by a zeromq publisher. + * @param publisherUrl Url of remote zeromq publisher + * @param subscribe topic to subscribe to + * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence + * of byte thus it needs the converter(which might be deserializer of bytes) + * to translate from sequence of sequence of bytes, where sequence refer to a frame + * and sub sequence refer to its payload. + * @param storageLevel Storage level to use for storing the received objects + */ + def zeroMQStream[T]( + publisherUrl:String, + subscribe: Subscribe, + bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T], + storageLevel: StorageLevel, + supervisorStrategy: SupervisorStrategy + ): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + ssc.zeroMQStream[T](publisherUrl, subscribe, bytesToObjects, storageLevel, supervisorStrategy) + } + + /** + * Create an input stream that receives messages pushed by a zeromq publisher. + * @param publisherUrl Url of remote zeromq publisher + * @param subscribe topic to subscribe to + * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence + * of byte thus it needs the converter(which might be deserializer of bytes) + * to translate from sequence of sequence of bytes, where sequence refer to a frame + * and sub sequence refer to its payload. + * @param storageLevel RDD storage level. Defaults to memory-only. + */ + def zeroMQStream[T]( + publisherUrl:String, + subscribe: Subscribe, + bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]], + storageLevel: StorageLevel + ): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + def fn(x: Seq[Seq[Byte]]) = bytesToObjects.apply(x.map(_.toArray).toArray).toIterator + ssc.zeroMQStream[T](publisherUrl, subscribe, fn, storageLevel) + } + + /** + * Create an input stream that receives messages pushed by a zeromq publisher. + * @param publisherUrl Url of remote zeromq publisher + * @param subscribe topic to subscribe to + * @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence + * of byte thus it needs the converter(which might be deserializer of bytes) + * to translate from sequence of sequence of bytes, where sequence refer to a frame + * and sub sequence refer to its payload. + */ + def zeroMQStream[T]( + publisherUrl:String, + subscribe: Subscribe, + bytesToObjects: JFunction[Array[Array[Byte]], java.lang.Iterable[T]] + ): JavaDStream[T] = { + implicit val cm: ClassManifest[T] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + def fn(x: Seq[Seq[Byte]]) = bytesToObjects.apply(x.map(_.toArray).toArray).toIterator + ssc.zeroMQStream[T](publisherUrl, subscribe, fn) + } + /** * Registers an output stream that will be computed every interval */ - def registerOutputStream(outputStream: JavaDStreamLike[_, _]) { + def registerOutputStream(outputStream: JavaDStreamLike[_, _, _]) { ssc.registerOutputStream(outputStream.dstream) } @@ -322,12 +491,11 @@ class JavaStreamingContext(val ssc: StreamingContext) { /** * Sets the context to periodically checkpoint the DStream operations for master - * fault-tolerance. By default, the graph will be checkpointed every batch interval. + * fault-tolerance. The graph will be checkpointed every batch interval. * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored - * @param interval checkpoint interval */ - def checkpoint(directory: String, interval: Duration = null) { - ssc.checkpoint(directory, interval) + def checkpoint(directory: String) { + ssc.checkpoint(directory) } /** diff --git a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala index ddb1bf6b28fdb4466dacf1f6b827e8086dc9090c..4ef4bb7de1023f2a3159cccb58d70aeb56a20bf5 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/CoGroupedDStream.scala @@ -6,7 +6,7 @@ import spark.streaming.{Time, DStream, Duration} private[streaming] class CoGroupedDStream[K : ClassManifest]( - parents: Seq[DStream[(_, _)]], + parents: Seq[DStream[(K, _)]], partitioner: Partitioner ) extends DStream[(K, Seq[Seq[_]])](parents.head.ssc) { diff --git a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala index 1e6ad84b44014c4d1c063b3b1effed934d532c87..41b9bd9461d9cc5ec93efe43bebdad62de0d9223 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FileInputDStream.scala @@ -2,13 +2,14 @@ package spark.streaming.dstream import spark.RDD import spark.rdd.UnionRDD -import spark.streaming.{StreamingContext, Time} +import spark.streaming.{DStreamCheckpointData, StreamingContext, Time} import org.apache.hadoop.fs.{FileSystem, Path, PathFilter} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} -import scala.collection.mutable.HashSet +import scala.collection.mutable.{HashSet, HashMap} +import java.io.{ObjectInputStream, IOException} private[streaming] class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K,V] : ClassManifest]( @@ -18,28 +19,23 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K newFilesOnly: Boolean = true) extends InputDStream[(K, V)](ssc_) { - @transient private var path_ : Path = null - @transient private var fs_ : FileSystem = null - - var lastModTime = 0L - val lastModTimeFiles = new HashSet[String]() + protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData - def path(): Path = { - if (path_ == null) path_ = new Path(directory) - path_ - } + // Latest file mod time seen till any point of time + private val lastModTimeFiles = new HashSet[String]() + private var lastModTime = 0L - def fs(): FileSystem = { - if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) - fs_ - } + @transient private var path_ : Path = null + @transient private var fs_ : FileSystem = null + @transient private[streaming] var files = new HashMap[Time, Array[String]] override def start() { if (newFilesOnly) { - lastModTime = System.currentTimeMillis() + lastModTime = graph.zeroTime.milliseconds } else { lastModTime = 0 } + logDebug("LastModTime initialized to " + lastModTime + ", new files only = " + newFilesOnly) } override def stop() { } @@ -49,38 +45,50 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K * a union RDD out of them. Note that this maintains the list of files that were processed * in the latest modification time in the previous call to this method. This is because the * modification time returned by the FileStatus API seems to return times only at the - * granularity of seconds. Hence, new files may have the same modification time as the - * latest modification time in the previous call to this method and the list of files - * maintained is used to filter the one that have been processed. + * granularity of seconds. And new files may have the same modification time as the + * latest modification time in the previous call to this method yet was not reported in + * the previous call. */ override def compute(validTime: Time): Option[RDD[(K, V)]] = { + assert(validTime.milliseconds >= lastModTime, "Trying to get new files for really old time [" + validTime + " < " + lastModTime) + // Create the filter for selecting new files val newFilter = new PathFilter() { + // Latest file mod time seen in this round of fetching files and its corresponding files var latestModTime = 0L val latestModTimeFiles = new HashSet[String]() def accept(path: Path): Boolean = { - if (!filter(path)) { + if (!filter(path)) { // Reject file if it does not satisfy filter + logDebug("Rejected by filter " + path) return false - } else { + } else { // Accept file only if val modTime = fs.getFileStatus(path).getModificationTime() - if (modTime < lastModTime){ - return false + logDebug("Mod time for " + path + " is " + modTime) + if (modTime < lastModTime) { + logDebug("Mod time less than last mod time") + return false // If the file was created before the last time it was called } else if (modTime == lastModTime && lastModTimeFiles.contains(path.toString)) { - return false + logDebug("Mod time equal to last mod time, but file considered already") + return false // If the file was created exactly as lastModTime but not reported yet + } else if (modTime > validTime.milliseconds) { + logDebug("Mod time more than valid time") + return false // If the file was created after the time this function call requires } if (modTime > latestModTime) { latestModTime = modTime latestModTimeFiles.clear() + logDebug("Latest mod time updated to " + latestModTime) } latestModTimeFiles += path.toString + logDebug("Accepted " + path) return true } } } - - val newFiles = fs.listStatus(path, newFilter) - logInfo("New files: " + newFiles.map(_.getPath).mkString(", ")) + logDebug("Finding new files at time " + validTime + " for last mod time = " + lastModTime) + val newFiles = fs.listStatus(path, newFilter).map(_.getPath.toString) + logInfo("New files at time " + validTime + ":\n" + newFiles.mkString("\n")) if (newFiles.length > 0) { // Update the modification time and the files processed for that modification time if (lastModTime != newFilter.latestModTime) { @@ -88,10 +96,81 @@ class FileInputDStream[K: ClassManifest, V: ClassManifest, F <: NewInputFormat[K lastModTimeFiles.clear() } lastModTimeFiles ++= newFilter.latestModTimeFiles + logDebug("Last mod time updated to " + lastModTime) + } + files += ((validTime, newFiles)) + Some(filesToRDD(newFiles)) + } + + /** Clear the old time-to-files mappings along with old RDDs */ + protected[streaming] override def clearOldMetadata(time: Time) { + super.clearOldMetadata(time) + val oldFiles = files.filter(_._1 <= (time - rememberDuration)) + files --= oldFiles.keys + logInfo("Cleared " + oldFiles.size + " old files that were older than " + + (time - rememberDuration) + ": " + oldFiles.keys.mkString(", ")) + logDebug("Cleared files are:\n" + + oldFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n")) + } + + /** Generate one RDD from an array of files */ + protected[streaming] def filesToRDD(files: Seq[String]): RDD[(K, V)] = { + new UnionRDD( + context.sparkContext, + files.map(file => context.sparkContext.newAPIHadoopFile[K, V, F](file)) + ) + } + + private def path: Path = { + if (path_ == null) path_ = new Path(directory) + path_ + } + + private def fs: FileSystem = { + if (fs_ == null) fs_ = path.getFileSystem(new Configuration()) + fs_ + } + + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + logDebug(this.getClass().getSimpleName + ".readObject used") + ois.defaultReadObject() + generatedRDDs = new HashMap[Time, RDD[(K,V)]] () + files = new HashMap[Time, Array[String]] + } + + /** + * A custom version of the DStreamCheckpointData that stores names of + * Hadoop files as checkpoint data. + */ + private[streaming] + class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { + + def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]] + + override def update() { + hadoopFiles.clear() + hadoopFiles ++= files + } + + override def cleanup() { } + + override def restore() { + hadoopFiles.foreach { + case (t, f) => { + // Restore the metadata in both files and generatedRDDs + logInfo("Restoring files for time " + t + " - " + + f.mkString("[", ", ", "]") ) + files += ((t, f)) + generatedRDDs += ((t, filesToRDD(f))) + } + } + } + + override def toString() = { + "[\n" + hadoopFiles.size + " file sets\n" + + hadoopFiles.map(p => (p._1, p._2.mkString(", "))).mkString("\n") + "\n]" } - val newRDD = new UnionRDD(ssc.sc, newFiles.map( - file => ssc.sc.newAPIHadoopFile[K, V, F](file.getPath.toString))) - Some(newRDD) } } @@ -100,3 +179,4 @@ object FileInputDStream { def defaultFilter(path: Path): Boolean = !path.getName().startsWith(".") } + diff --git a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala index efc70584805697bffa744f82664ff59c9578cbf6..c9644b3a83aa053e96519cab70d5754b6b6803c3 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/FlumeInputDStream.scala @@ -25,7 +25,7 @@ class FlumeInputDStream[T: ClassManifest]( storageLevel: StorageLevel ) extends NetworkInputDStream[SparkFlumeEvent](ssc_) { - override def createReceiver(): NetworkReceiver[SparkFlumeEvent] = { + override def getReceiver(): NetworkReceiver[SparkFlumeEvent] = { new FlumeReceiver(host, port, storageLevel) } } @@ -134,4 +134,4 @@ class FlumeReceiver( } override def getLocationPreference = Some(host) -} \ No newline at end of file +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala index 980ca5177eb1d20a5b8b8a85155e914ff94e1511..3c5d43a60955a38273aabd224073c8fabf2830d4 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/InputDStream.scala @@ -1,10 +1,42 @@ package spark.streaming.dstream -import spark.streaming.{Duration, StreamingContext, DStream} +import spark.streaming.{Time, Duration, StreamingContext, DStream} +/** + * This is the abstract base class for all input streams. This class provides to methods + * start() and stop() which called by the scheduler to start and stop receiving data/ + * Input streams that can generated RDDs from new data just by running a service on + * the driver node (that is, without running a receiver onworker nodes) can be + * implemented by directly subclassing this InputDStream. For example, + * FileInputDStream, a subclass of InputDStream, monitors a HDFS directory for + * new files and generates RDDs on the new files. For implementing input streams + * that requires running a receiver on the worker nodes, use NetworkInputDStream + * as the parent class. + */ abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContext) extends DStream[T](ssc_) { + var lastValidTime: Time = null + + /** + * Checks whether the 'time' is valid wrt slideDuration for generating RDD. + * Additionally it also ensures valid times are in strictly increasing order. + * This ensures that InputDStream.compute() is called strictly on increasing + * times. + */ + override protected def isTimeValid(time: Time): Boolean = { + if (!super.isTimeValid(time)) { + false // Time not valid + } else { + // Time is valid, but check it it is more than lastValidTime + if (lastValidTime != null && time < lastValidTime) { + logWarning("isTimeValid called with " + time + " where as last valid time is " + lastValidTime) + } + lastValidTime = time + true + } + } + override def dependencies = List() override def slideDuration: Duration = { @@ -13,7 +45,9 @@ abstract class InputDStream[T: ClassManifest] (@transient ssc_ : StreamingContex ssc.graph.batchDuration } + /** Method called to start receiving data. Subclasses must implement this method. */ def start() + /** Method called to stop receiving data. Subclasses must implement this method. */ def stop() } diff --git a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala index 2b4740bdf7c7df14e7fe5cbeea1b1d89096714ae..dc7139cc273cf75be9dd85e1caaa8579dc052d7c 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/KafkaInputDStream.scala @@ -19,21 +19,11 @@ import scala.collection.JavaConversions._ // Key for a specific Kafka Partition: (broker, topic, group, part) case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) -// NOT USED - Originally intended for fault-tolerance -// Metadata for a Kafka Stream that it sent to the Master -private[streaming] -case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) -// NOT USED - Originally intended for fault-tolerance -// Checkpoint data specific to a KafkaInputDstream -private[streaming] -case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], - savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds) /** * Input stream that pulls messages from a Kafka Broker. * - * @param host Zookeper hostname. - * @param port Zookeper port. + * @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..). * @param groupId The group id for this consumer. * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. @@ -44,65 +34,22 @@ case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], private[streaming] class KafkaInputDStream[T: ClassManifest]( @transient ssc_ : StreamingContext, - host: String, - port: Int, + zkQuorum: String, groupId: String, topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel ) extends NetworkInputDStream[T](ssc_ ) with Logging { - // Metadata that keeps track of which messages have already been consumed. - var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]() - - /* NOT USED - Originally intended for fault-tolerance - - // In case of a failure, the offets for a particular timestamp will be restored. - @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null - - - override protected[streaming] def addMetadata(metadata: Any) { - metadata match { - case x : KafkaInputDStreamMetadata => - savedOffsets(x.timestamp) = x.data - // TOOD: Remove logging - logInfo("New saved Offsets: " + savedOffsets) - case _ => logInfo("Received unknown metadata: " + metadata.toString) - } - } - - override protected[streaming] def updateCheckpointData(currentTime: Time) { - super.updateCheckpointData(currentTime) - if(savedOffsets.size > 0) { - // Find the offets that were stored before the checkpoint was initiated - val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last - val latestOffsets = savedOffsets(key) - logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString) - checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets) - // TODO: This may throw out offsets that are created after the checkpoint, - // but it's unlikely we'll need them. - savedOffsets.clear() - } - } - - override protected[streaming] def restoreCheckpointData() { - super.restoreCheckpointData() - logInfo("Restoring KafkaDStream checkpoint data.") - checkpointData match { - case x : KafkaDStreamCheckpointData => - restoredOffsets = x.savedOffsets - logInfo("Restored KafkaDStream offsets: " + savedOffsets) - } - } */ - def createReceiver(): NetworkReceiver[T] = { - new KafkaReceiver(host, port, groupId, topics, initialOffsets, storageLevel) + def getReceiver(): NetworkReceiver[T] = { + new KafkaReceiver(zkQuorum, groupId, topics, initialOffsets, storageLevel) .asInstanceOf[NetworkReceiver[T]] } } private[streaming] -class KafkaReceiver(host: String, port: Int, groupId: String, +class KafkaReceiver(zkQuorum: String, groupId: String, topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], storageLevel: StorageLevel) extends NetworkReceiver[Any] { @@ -111,8 +58,6 @@ class KafkaReceiver(host: String, port: Int, groupId: String, // Handles pushing data into the BlockManager lazy protected val blockGenerator = new BlockGenerator(storageLevel) - // Keeps track of the current offsets. Maps from (broker, topic, group, part) -> Offset - lazy val offsets = HashMap[KafkaPartitionKey, Long]() // Connection to Kafka var consumerConnector : ZookeeperConsumerConnector = null @@ -127,24 +72,23 @@ class KafkaReceiver(host: String, port: Int, groupId: String, // In case we are using multiple Threads to handle Kafka Messages val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) - val zooKeeperEndPoint = host + ":" + port logInfo("Starting Kafka Consumer Stream with group: " + groupId) logInfo("Initial offsets: " + initialOffsets.toString) - + // Zookeper connection properties val props = new Properties() - props.put("zk.connect", zooKeeperEndPoint) + props.put("zk.connect", zkQuorum) props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) props.put("groupid", groupId) // Create the connection to the cluster - logInfo("Connecting to Zookeper: " + zooKeeperEndPoint) + logInfo("Connecting to Zookeper: " + zkQuorum) val consumerConfig = new ConsumerConfig(props) consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] - logInfo("Connected to " + zooKeeperEndPoint) + logInfo("Connected to " + zkQuorum) - // Reset the Kafka offsets in case we are recovering from a failure - resetOffsets(initialOffsets) + // If specified, set the topic offset + setOffsets(initialOffsets) // Create Threads for each Topic/Message Stream we are listening val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) @@ -157,11 +101,11 @@ class KafkaReceiver(host: String, port: Int, groupId: String, } // Overwrites the offets in Zookeper. - private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) { + private def setOffsets(offsets: Map[KafkaPartitionKey, Long]) { offsets.foreach { case(key, offset) => val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) val partitionName = key.brokerId + "-" + key.partId - updatePersistentPath(consumerConnector.zkClient, + updatePersistentPath(consumerConnector.zkClient, topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString) } } @@ -172,29 +116,10 @@ class KafkaReceiver(host: String, port: Int, groupId: String, logInfo("Starting MessageHandler.") stream.takeWhile { msgAndMetadata => blockGenerator += msgAndMetadata.message - - // Updating the offet. The key is (broker, topic, group, partition). - val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic, - groupId, msgAndMetadata.topicInfo.partition.partId) - val offset = msgAndMetadata.topicInfo.getConsumeOffset - offsets.put(key, offset) - // logInfo("Handled message: " + (key, offset).toString) - // Keep on handling messages + true - } + } } } - - // NOT USED - Originally intended for fault-tolerance - // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) - // extends BufferingBlockCreator[Any](receiver, storageLevel) { - - // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { - // // Creates a new Block with Kafka-specific Metadata - // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) - // } - - // } - } diff --git a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala index 8c322dd698099be5426008e0cadf792f002a736c..7385474963ebb70bfd7f8f86a16282e341824c80 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/NetworkInputDStream.scala @@ -20,7 +20,7 @@ import java.util.concurrent.ArrayBlockingQueue /** * Abstract class for defining any InputDStream that has to start a receiver on worker * nodes to receive external data. Specific implementations of NetworkInputDStream must - * define the createReceiver() function that creates the receiver object of type + * define the getReceiver() function that gets the receiver object of type * [[spark.streaming.dstream.NetworkReceiver]] that will be sent to the workers to receive * data. * @param ssc_ Streaming context that will execute this input stream @@ -34,11 +34,11 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming val id = ssc.getNewNetworkStreamId() /** - * Creates the receiver object that will be sent to the worker nodes + * Gets the receiver object that will be sent to the worker nodes * to receive data. This method needs to defined by any specific implementation * of a NetworkInputDStream. */ - def createReceiver(): NetworkReceiver[T] + def getReceiver(): NetworkReceiver[T] // Nothing to start or stop as both taken care of by the NetworkInputTracker. def start() {} @@ -46,8 +46,15 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming def stop() {} override def compute(validTime: Time): Option[RDD[T]] = { - val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) - Some(new BlockRDD[T](ssc.sc, blockIds)) + // If this is called for any time before the start time of the context, + // then this returns an empty RDD. This may happen when recovering from a + // master failure + if (validTime >= graph.startTime) { + val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) + Some(new BlockRDD[T](ssc.sc, blockIds)) + } else { + Some(new BlockRDD[T](ssc.sc, Array[String]())) + } } } diff --git a/streaming/src/main/scala/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/PluggableInputDStream.scala new file mode 100644 index 0000000000000000000000000000000000000000..3c2a81947b96b5b678b76e432cb865c79892c803 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/dstream/PluggableInputDStream.scala @@ -0,0 +1,13 @@ +package spark.streaming.dstream + +import spark.streaming.StreamingContext + +private[streaming] +class PluggableInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + receiver: NetworkReceiver[T]) extends NetworkInputDStream[T](ssc_) { + + def getReceiver(): NetworkReceiver[T] = { + receiver + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala index 024bf3bea47832867e2ce419e41e56786fa2c73d..6b310bc0b611c7a6260042e3ddc555a59bed28a2 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/QueueInputDStream.scala @@ -7,6 +7,7 @@ import scala.collection.mutable.Queue import scala.collection.mutable.ArrayBuffer import spark.streaming.{Time, StreamingContext} +private[streaming] class QueueInputDStream[T: ClassManifest]( @transient ssc: StreamingContext, val queue: Queue[RDD[T]], diff --git a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala index 04e6b69b7ba9d41a12a400d417df0150d9bd9c1c..1b2fa567795d4245567a42362a6f45e4f8d96e8b 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/RawInputDStream.scala @@ -25,7 +25,7 @@ class RawInputDStream[T: ClassManifest]( storageLevel: StorageLevel ) extends NetworkInputDStream[T](ssc_ ) with Logging { - def createReceiver(): NetworkReceiver[T] = { + def getReceiver(): NetworkReceiver[T] = { new RawNetworkReceiver(host, port, storageLevel).asInstanceOf[NetworkReceiver[T]] } } diff --git a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala index 733d5c4a25271cbd741083dd392590bcb176298b..343b6915e79a36b02292bace746681474f0e3294 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -3,7 +3,7 @@ package spark.streaming.dstream import spark.streaming.StreamingContext._ import spark.RDD -import spark.rdd.CoGroupedRDD +import spark.rdd.{CoGroupedRDD, MapPartitionsRDD} import spark.Partitioner import spark.SparkContext._ import spark.storage.StorageLevel @@ -15,7 +15,8 @@ private[streaming] class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( parent: DStream[(K, V)], reduceFunc: (V, V) => V, - invReduceFunc: (V, V) => V, + invReduceFunc: (V, V) => V, + filterFunc: Option[((K, V)) => Boolean], _windowDuration: Duration, _slideDuration: Duration, partitioner: Partitioner @@ -87,21 +88,24 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( // // Get the RDDs of the reduced values in "old time steps" - val oldRDDs = reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration) + val oldRDDs = + reducedStream.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration) logDebug("# old RDDs = " + oldRDDs.size) // Get the RDDs of the reduced values in "new time steps" - val newRDDs = reducedStream.slice(previousWindow.endTime + parent.slideDuration, currentWindow.endTime) + val newRDDs = + reducedStream.slice(previousWindow.endTime + parent.slideDuration, currentWindow.endTime) logDebug("# new RDDs = " + newRDDs.size) // Get the RDD of the reduced value of the previous window - val previousWindowRDD = getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) + val previousWindowRDD = + getOrCompute(previousWindow.endTime).getOrElse(ssc.sc.makeRDD(Seq[(K,V)]())) // Make the list of RDDs that needs to cogrouped together for reducing their reduced values val allRDDs = new ArrayBuffer[RDD[(K, V)]]() += previousWindowRDD ++= oldRDDs ++= newRDDs // Cogroup the reduced RDDs and merge the reduced values - val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(_, _)]]], partitioner) + val cogroupedRDD = new CoGroupedRDD[K](allRDDs.toSeq.asInstanceOf[Seq[RDD[(K, _)]]], partitioner) //val mergeValuesFunc = mergeValues(oldRDDs.size, newRDDs.size) _ val numOldValues = oldRDDs.size @@ -114,7 +118,9 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( // Getting reduced values "old time steps" that will be removed from current window val oldValues = (1 to numOldValues).map(i => seqOfValues(i)).filter(!_.isEmpty).map(_.head) // Getting reduced values "new time steps" - val newValues = (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + val newValues = + (1 to numNewValues).map(i => seqOfValues(numOldValues + i)).filter(!_.isEmpty).map(_.head) + if (seqOfValues(0).isEmpty) { // If previous window's reduce value does not exist, then at least new values should exist if (newValues.isEmpty) { @@ -140,10 +146,12 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( val mergedValuesRDD = cogroupedRDD.asInstanceOf[RDD[(K,Seq[Seq[V]])]].mapValues(mergeValues) - Some(mergedValuesRDD) + if (filterFunc.isDefined) { + Some(mergedValuesRDD.filter(filterFunc.get)) + } else { + Some(mergedValuesRDD) + } } - - } diff --git a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala index d42027092b620aa51b4ba36bc3a88310fdbd6299..4af839ad7f03d1e75dd19715a2cb0785a9a91628 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala @@ -15,7 +15,7 @@ class SocketInputDStream[T: ClassManifest]( storageLevel: StorageLevel ) extends NetworkInputDStream[T](ssc_) { - def createReceiver(): NetworkReceiver[T] = { + def getReceiver(): NetworkReceiver[T] = { new SocketReceiver(host, port, bytesToObjects, storageLevel) } } diff --git a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala index b4506c74aa4e14daf720d560c52ae24b69c5aeac..db62955036e9c1a72a5e11571628a97ca858af53 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/StateDStream.scala @@ -48,8 +48,16 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S: ClassManifest]( //logDebug("Generating state RDD for time " + validTime) return Some(stateRDD) } - case None => { // If parent RDD does not exist, then return old state RDD - return Some(prevStateRDD) + case None => { // If parent RDD does not exist + + // Re-apply the update function to the old state RDD + val updateFuncLocal = updateFunc + val finalFunc = (iterator: Iterator[(K, S)]) => { + val i = iterator.map(t => (t._1, Seq[V](), Option(t._2))) + updateFuncLocal(i) + } + val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning) + return Some(stateRDD) } } } diff --git a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala similarity index 87% rename from examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala rename to streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala index 99ed4cdc1c12d13ef8367f2f5323fb7927b51ca2..c69749886289cd6e49401731f7970628e99115c3 100644 --- a/examples/src/main/scala/spark/streaming/examples/twitter/TwitterInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/TwitterInputDStream.scala @@ -1,12 +1,11 @@ -package spark.streaming.examples.twitter +package spark.streaming.dstream import spark._ import spark.streaming._ -import dstream.{NetworkReceiver, NetworkInputDStream} import storage.StorageLevel + import twitter4j._ import twitter4j.auth.BasicAuthorization -import collection.JavaConversions._ /* A stream of Twitter statuses, potentially filtered by one or more keywords. * @@ -14,19 +13,21 @@ import collection.JavaConversions._ * An optional set of string filters can be used to restrict the set of tweets. The Twitter API is * such that this may return a sampled subset of all tweets during each interval. */ +private[streaming] class TwitterInputDStream( @transient ssc_ : StreamingContext, username: String, password: String, filters: Seq[String], storageLevel: StorageLevel - ) extends NetworkInputDStream[Status](ssc_) { + ) extends NetworkInputDStream[Status](ssc_) { - override def createReceiver(): NetworkReceiver[Status] = { + override def getReceiver(): NetworkReceiver[Status] = { new TwitterReceiver(username, password, filters, storageLevel) } } +private[streaming] class TwitterReceiver( username: String, password: String, @@ -50,7 +51,7 @@ class TwitterReceiver( def onTrackLimitationNotice(i: Int) {} def onScrubGeo(l: Long, l1: Long) {} def onStallWarning(stallWarning: StallWarning) {} - def onException(e: Exception) {} + def onException(e: Exception) { stopOnError(e) } }) val query: FilterQuery = new FilterQuery diff --git a/streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala new file mode 100644 index 0000000000000000000000000000000000000000..b3201d0b28d797c3a8d7e200e19b25fff878730b --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/receivers/ActorReceiver.scala @@ -0,0 +1,153 @@ +package spark.streaming.receivers + +import akka.actor.{ Actor, PoisonPill, Props, SupervisorStrategy } +import akka.actor.{ actorRef2Scala, ActorRef } +import akka.actor.{ PossiblyHarmful, OneForOneStrategy } + +import spark.storage.StorageLevel +import spark.streaming.dstream.NetworkReceiver + +import java.util.concurrent.atomic.AtomicInteger + +/** A helper with set of defaults for supervisor strategy **/ +object ReceiverSupervisorStrategy { + + import akka.util.duration._ + import akka.actor.SupervisorStrategy._ + + val defaultStrategy = OneForOneStrategy(maxNrOfRetries = 10, withinTimeRange = + 15 millis) { + case _: RuntimeException ⇒ Restart + case _: Exception ⇒ Escalate + } +} + +/** + * A receiver trait to be mixed in with your Actor to gain access to + * pushBlock API. + * + * @example {{{ + * class MyActor extends Actor with Receiver{ + * def receive { + * case anything :String ⇒ pushBlock(anything) + * } + * } + * //Can be plugged in actorStream as follows + * ssc.actorStream[String](Props(new MyActor),"MyActorReceiver") + * + * }}} + * + * @note An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e parametrized type of push block and InputDStream + * should be same. + * + */ +trait Receiver { self: Actor ⇒ + def pushBlock[T: ClassManifest](iter: Iterator[T]) { + context.parent ! Data(iter) + } + + def pushBlock[T: ClassManifest](data: T) { + context.parent ! Data(data) + } + +} + +/** + * Statistics for querying the supervisor about state of workers + */ +case class Statistics(numberOfMsgs: Int, + numberOfWorkers: Int, + numberOfHiccups: Int, + otherInfo: String) + +/** Case class to receive data sent by child actors **/ +private[streaming] case class Data[T: ClassManifest](data: T) + +/** + * Provides Actors as receivers for receiving stream. + * + * As Actors can also be used to receive data from almost any stream source. + * A nice set of abstraction(s) for actors as receivers is already provided for + * a few general cases. It is thus exposed as an API where user may come with + * his own Actor to run as receiver for Spark Streaming input source. + * + * This starts a supervisor actor which starts workers and also provides + * [http://doc.akka.io/docs/akka/2.0.5/scala/fault-tolerance.html fault-tolerance]. + * + * Here's a way to start more supervisor/workers as its children. + * + * @example {{{ + * context.parent ! Props(new Supervisor) + * }}} OR {{{ + * context.parent ! Props(new Worker,"Worker") + * }}} + * + * + */ +private[streaming] class ActorReceiver[T: ClassManifest]( + props: Props, + name: String, + storageLevel: StorageLevel, + receiverSupervisorStrategy: SupervisorStrategy) + extends NetworkReceiver[T] { + + protected lazy val blocksGenerator: BlockGenerator = + new BlockGenerator(storageLevel) + + protected lazy val supervisor = env.actorSystem.actorOf(Props(new Supervisor), + "Supervisor" + streamId) + + private class Supervisor extends Actor { + + override val supervisorStrategy = receiverSupervisorStrategy + val worker = context.actorOf(props, name) + logInfo("Started receiver worker at:" + worker.path) + + val n: AtomicInteger = new AtomicInteger(0) + val hiccups: AtomicInteger = new AtomicInteger(0) + + def receive = { + + case Data(iter: Iterator[_]) ⇒ pushBlock(iter.asInstanceOf[Iterator[T]]) + + case Data(msg) ⇒ + blocksGenerator += msg.asInstanceOf[T] + n.incrementAndGet + + case props: Props ⇒ + val worker = context.actorOf(props) + logInfo("Started receiver worker at:" + worker.path) + sender ! worker + + case (props: Props, name: String) ⇒ + val worker = context.actorOf(props, name) + logInfo("Started receiver worker at:" + worker.path) + sender ! worker + + case _: PossiblyHarmful => hiccups.incrementAndGet() + + case _: Statistics ⇒ + val workers = context.children + sender ! Statistics(n.get, workers.size, hiccups.get, workers.mkString("\n")) + + } + } + + protected def pushBlock(iter: Iterator[T]) { + pushBlock("block-" + streamId + "-" + System.nanoTime(), + iter, null, storageLevel) + } + + protected def onStart() = { + blocksGenerator.start() + supervisor + logInfo("Supervision tree for receivers initialized at:" + supervisor.path) + } + + protected def onStop() = { + supervisor ! PoisonPill + } + +} diff --git a/streaming/src/main/scala/spark/streaming/receivers/ZeroMQReceiver.scala b/streaming/src/main/scala/spark/streaming/receivers/ZeroMQReceiver.scala new file mode 100644 index 0000000000000000000000000000000000000000..5533c3cf1ef8b5a316aea8b8fa2b9c88f60872fe --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/receivers/ZeroMQReceiver.scala @@ -0,0 +1,33 @@ +package spark.streaming.receivers + +import akka.actor.Actor +import akka.zeromq._ + +import spark.Logging + +/** + * A receiver to subscribe to ZeroMQ stream. + */ +private[streaming] class ZeroMQReceiver[T: ClassManifest](publisherUrl: String, + subscribe: Subscribe, + bytesToObjects: Seq[Seq[Byte]] ⇒ Iterator[T]) + extends Actor with Receiver with Logging { + + override def preStart() = context.system.newSocket(SocketType.Sub, Listener(self), + Connect(publisherUrl), subscribe) + + def receive: Receive = { + + case Connecting ⇒ logInfo("connecting ...") + + case m: ZMQMessage ⇒ + logDebug("Received message for:" + m.firstFrameAsString) + + //We ignore first frame for processing as it is the topic + val bytes = m.frames.tail.map(_.payload) + pushBlock(bytesToObjects(bytes)) + + case Closed ⇒ logInfo("received closed ") + + } +} diff --git a/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala new file mode 100644 index 0000000000000000000000000000000000000000..bdd9f4d7535ea0fb055d29ccc5eccf88dd83b958 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/util/MasterFailureTest.scala @@ -0,0 +1,392 @@ +package spark.streaming.util + +import spark.{Logging, RDD} +import spark.streaming._ +import spark.streaming.dstream.ForEachDStream +import StreamingContext._ + +import scala.util.Random +import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} + +import java.io.{File, ObjectInputStream, IOException} +import java.util.UUID + +import com.google.common.io.Files + +import org.apache.commons.io.FileUtils +import org.apache.hadoop.fs.{FileUtil, FileSystem, Path} +import org.apache.hadoop.conf.Configuration + + +private[streaming] +object MasterFailureTest extends Logging { + initLogging() + + @volatile var killed = false + @volatile var killCount = 0 + + def main(args: Array[String]) { + if (args.size < 2) { + println( + "Usage: MasterFailureTest <local/HDFS directory> <# batches> [<batch size in milliseconds>]") + System.exit(1) + } + val directory = args(0) + val numBatches = args(1).toInt + val batchDuration = if (args.size > 2) Milliseconds(args(2).toInt) else Seconds(1) + + println("\n\n========================= MAP TEST =========================\n\n") + testMap(directory, numBatches, batchDuration) + + println("\n\n================= UPDATE-STATE-BY-KEY TEST =================\n\n") + testUpdateStateByKey(directory, numBatches, batchDuration) + + println("\n\nSUCCESS\n\n") + } + + def testMap(directory: String, numBatches: Int, batchDuration: Duration) { + // Input: time=1 ==> [ 1 ] , time=2 ==> [ 2 ] , time=3 ==> [ 3 ] , ... + val input = (1 to numBatches).map(_.toString).toSeq + // Expected output: time=1 ==> [ 1 ] , time=2 ==> [ 2 ] , time=3 ==> [ 3 ] , ... + val expectedOutput = (1 to numBatches) + + val operation = (st: DStream[String]) => st.map(_.toInt) + + // Run streaming operation with multiple master failures + val output = testOperation(directory, batchDuration, input, operation, expectedOutput) + + logInfo("Expected output, size = " + expectedOutput.size) + logInfo(expectedOutput.mkString("[", ",", "]")) + logInfo("Output, size = " + output.size) + logInfo(output.mkString("[", ",", "]")) + + // Verify whether all the values of the expected output is present + // in the output + assert(output.distinct.toSet == expectedOutput.toSet) + } + + + def testUpdateStateByKey(directory: String, numBatches: Int, batchDuration: Duration) { + // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... + val input = (1 to numBatches).map(i => (1 to i).map(_ => "a").mkString(" ")).toSeq + // Expected output: time=1 ==> [ (a, 1) ] , time=2 ==> [ (a, 3) ] , time=3 ==> [ (a,6) ] , ... + val expectedOutput = (1L to numBatches).map(i => (1L to i).reduce(_ + _)).map(j => ("a", j)) + + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Long], state: Option[Long]) => { + Some(values.foldLeft(0L)(_ + _) + state.getOrElse(0L)) + } + st.flatMap(_.split(" ")) + .map(x => (x, 1L)) + .updateStateByKey[Long](updateFunc) + .checkpoint(batchDuration * 5) + } + + // Run streaming operation with multiple master failures + val output = testOperation(directory, batchDuration, input, operation, expectedOutput) + + logInfo("Expected output, size = " + expectedOutput.size + "\n" + expectedOutput) + logInfo("Output, size = " + output.size + "\n" + output) + + // Verify whether all the values in the output are among the expected output values + output.foreach(o => + assert(expectedOutput.contains(o), "Expected value " + o + " not found") + ) + + // Verify whether the last expected output value has been generated, there by + // confirming that none of the inputs have been missed + assert(output.last == expectedOutput.last) + } + + /** + * Tests stream operation with multiple master failures, and verifies whether the + * final set of output values is as expected or not. + */ + def testOperation[T: ClassManifest]( + directory: String, + batchDuration: Duration, + input: Seq[String], + operation: DStream[String] => DStream[T], + expectedOutput: Seq[T] + ): Seq[T] = { + + // Just making sure that the expected output does not have duplicates + assert(expectedOutput.distinct.toSet == expectedOutput.toSet) + + // Setup the stream computation with the given operation + val (ssc, checkpointDir, testDir) = setupStreams(directory, batchDuration, operation) + + // Start generating files in the a different thread + val fileGeneratingThread = new FileGeneratingThread(input, testDir, batchDuration.milliseconds) + fileGeneratingThread.start() + + // Run the streams and repeatedly kill it until the last expected output + // has been generated, or until it has run for twice the expected time + val lastExpectedOutput = expectedOutput.last + val maxTimeToRun = expectedOutput.size * batchDuration.milliseconds * 2 + val mergedOutput = runStreams(ssc, lastExpectedOutput, maxTimeToRun) + + // Delete directories + fileGeneratingThread.join() + val fs = checkpointDir.getFileSystem(new Configuration()) + fs.delete(checkpointDir, true) + fs.delete(testDir, true) + logInfo("Finished test after " + killCount + " failures") + mergedOutput + } + + /** + * Sets up the stream computation with the given operation, directory (local or HDFS), + * and batch duration. Returns the streaming context and the directory to which + * files should be written for testing. + */ + private def setupStreams[T: ClassManifest]( + directory: String, + batchDuration: Duration, + operation: DStream[String] => DStream[T] + ): (StreamingContext, Path, Path) = { + // Reset all state + reset() + + // Create the directories for this test + val uuid = UUID.randomUUID().toString + val rootDir = new Path(directory, uuid) + val fs = rootDir.getFileSystem(new Configuration()) + val checkpointDir = new Path(rootDir, "checkpoint") + val testDir = new Path(rootDir, "test") + fs.mkdirs(checkpointDir) + fs.mkdirs(testDir) + + // Setup the streaming computation with the given operation + System.clearProperty("spark.driver.port") + var ssc = new StreamingContext("local[4]", "MasterFailureTest", batchDuration) + ssc.checkpoint(checkpointDir.toString) + val inputStream = ssc.textFileStream(testDir.toString) + val operatedStream = operation(inputStream) + val outputStream = new TestOutputStream(operatedStream) + ssc.registerOutputStream(outputStream) + (ssc, checkpointDir, testDir) + } + + + /** + * Repeatedly starts and kills the streaming context until timed out or + * the last expected output is generated. Finally, return + */ + private def runStreams[T: ClassManifest]( + ssc_ : StreamingContext, + lastExpectedOutput: T, + maxTimeToRun: Long + ): Seq[T] = { + + var ssc = ssc_ + var totalTimeRan = 0L + var isLastOutputGenerated = false + var isTimedOut = false + val mergedOutput = new ArrayBuffer[T]() + val checkpointDir = ssc.checkpointDir + var batchDuration = ssc.graph.batchDuration + + while(!isLastOutputGenerated && !isTimedOut) { + // Get the output buffer + val outputBuffer = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[T]].output + def output = outputBuffer.flatMap(x => x) + + // Start the thread to kill the streaming after some time + killed = false + val killingThread = new KillingThread(ssc, batchDuration.milliseconds * 10) + killingThread.start() + + var timeRan = 0L + try { + // Start the streaming computation and let it run while ... + // (i) StreamingContext has not been shut down yet + // (ii) The last expected output has not been generated yet + // (iii) Its not timed out yet + System.clearProperty("spark.streaming.clock") + System.clearProperty("spark.driver.port") + ssc.start() + val startTime = System.currentTimeMillis() + while (!killed && !isLastOutputGenerated && !isTimedOut) { + Thread.sleep(100) + timeRan = System.currentTimeMillis() - startTime + isLastOutputGenerated = (!output.isEmpty && output.last == lastExpectedOutput) + isTimedOut = (timeRan + totalTimeRan > maxTimeToRun) + } + } catch { + case e: Exception => logError("Error running streaming context", e) + } + if (killingThread.isAlive) killingThread.interrupt() + ssc.stop() + + logInfo("Has been killed = " + killed) + logInfo("Is last output generated = " + isLastOutputGenerated) + logInfo("Is timed out = " + isTimedOut) + + // Verify whether the output of each batch has only one element or no element + // and then merge the new output with all the earlier output + mergedOutput ++= output + totalTimeRan += timeRan + logInfo("New output = " + output) + logInfo("Merged output = " + mergedOutput) + logInfo("Time ran = " + timeRan) + logInfo("Total time ran = " + totalTimeRan) + + if (!isLastOutputGenerated && !isTimedOut) { + val sleepTime = Random.nextInt(batchDuration.milliseconds.toInt * 10) + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation in " + sleepTime + " ms " + + "\n-------------------------------------------\n" + ) + Thread.sleep(sleepTime) + // Recreate the streaming context from checkpoint + ssc = new StreamingContext(checkpointDir) + } + } + mergedOutput + } + + /** + * Verifies the output value are the same as expected. Since failures can lead to + * a batch being processed twice, a batches output may appear more than once + * consecutively. To avoid getting confused with those, we eliminate consecutive + * duplicate batch outputs of values from the `output`. As a result, the + * expected output should not have consecutive batches with the same values as output. + */ + private def verifyOutput[T: ClassManifest](output: Seq[T], expectedOutput: Seq[T]) { + // Verify whether expected outputs do not consecutive batches with same output + for (i <- 0 until expectedOutput.size - 1) { + assert(expectedOutput(i) != expectedOutput(i+1), + "Expected output has consecutive duplicate sequence of values") + } + + // Log the output + println("Expected output, size = " + expectedOutput.size) + println(expectedOutput.mkString("[", ",", "]")) + println("Output, size = " + output.size) + println(output.mkString("[", ",", "]")) + + // Match the output with the expected output + output.foreach(o => + assert(expectedOutput.contains(o), "Expected value " + o + " not found") + ) + } + + /** Resets counter to prepare for the test */ + private def reset() { + killed = false + killCount = 0 + } +} + +/** + * This is a output stream just for testing. All the output is collected into a + * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. + */ +private[streaming] +class TestOutputStream[T: ClassManifest]( + parent: DStream[T], + val output: ArrayBuffer[Seq[T]] = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] + ) extends ForEachDStream[T]( + parent, + (rdd: RDD[T], t: Time) => { + val collected = rdd.collect() + output += collected + } + ) { + + // This is to clear the output buffer every it is read from a checkpoint + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + ois.defaultReadObject() + output.clear() + } +} + + +/** + * Thread to kill streaming context after a random period of time. + */ +private[streaming] +class KillingThread(ssc: StreamingContext, maxKillWaitTime: Long) extends Thread with Logging { + initLogging() + + override def run() { + try { + // If it is the first killing, then allow the first checkpoint to be created + var minKillWaitTime = if (MasterFailureTest.killCount == 0) 5000 else 2000 + val killWaitTime = minKillWaitTime + math.abs(Random.nextLong % maxKillWaitTime) + logInfo("Kill wait time = " + killWaitTime) + Thread.sleep(killWaitTime) + logInfo( + "\n---------------------------------------\n" + + "Killing streaming context after " + killWaitTime + " ms" + + "\n---------------------------------------\n" + ) + if (ssc != null) { + ssc.stop() + MasterFailureTest.killed = true + MasterFailureTest.killCount += 1 + } + logInfo("Killing thread finished normally") + } catch { + case ie: InterruptedException => logInfo("Killing thread interrupted") + case e: Exception => logWarning("Exception in killing thread", e) + } + + } +} + + +/** + * Thread to generate input files periodically with the desired text. + */ +private[streaming] +class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) + extends Thread with Logging { + initLogging() + + override def run() { + val localTestDir = Files.createTempDir() + var fs = testDir.getFileSystem(new Configuration()) + val maxTries = 3 + try { + Thread.sleep(5000) // To make sure that all the streaming context has been set up + for (i <- 0 until input.size) { + // Write the data to a local file and then move it to the target test directory + val localFile = new File(localTestDir, (i+1).toString) + val hadoopFile = new Path(testDir, (i+1).toString) + FileUtils.writeStringToFile(localFile, input(i).toString + "\n") + var tries = 0 + var done = false + while (!done && tries < maxTries) { + tries += 1 + try { + fs.copyFromLocalFile(new Path(localFile.toString), hadoopFile) + done = true + } catch { + case ioe: IOException => { + fs = testDir.getFileSystem(new Configuration()) + logWarning("Attempt " + tries + " at generating file " + hadoopFile + " failed.", ioe) + } + } + } + if (!done) + logError("Could not generate file " + hadoopFile) + else + logInfo("Generated file " + hadoopFile + " at " + System.currentTimeMillis) + Thread.sleep(interval) + localFile.delete() + } + logInfo("File generating thread finished normally") + } catch { + case ie: InterruptedException => logInfo("File generating thread interrupted") + case e: Exception => logWarning("File generating in killing thread", e) + } finally { + fs.close() + } + } +} + + diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala index db715cc295205cfb76ac0b416dfaf039eaaa31f8..8e10276deb9056301f98712637208db3ce2a74b7 100644 --- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala @@ -3,9 +3,9 @@ package spark.streaming.util private[streaming] class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) { - val minPollTime = 25L + private val minPollTime = 25L - val pollTime = { + private val pollTime = { if (period / 10.0 > minPollTime) { (period / 10.0).toLong } else { @@ -13,11 +13,20 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => } } - val thread = new Thread() { + private val thread = new Thread() { override def run() { loop } } - var nextTime = 0L + private var nextTime = 0L + + def getStartTime(): Long = { + (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period + } + + def getRestartTime(originalStartTime: Long): Long = { + val gap = clock.currentTime - originalStartTime + (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime + } def start(startTime: Long): Long = { nextTime = startTime @@ -26,21 +35,14 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => } def start(): Long = { - val startTime = (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period - start(startTime) + start(getStartTime()) } - def restart(originalStartTime: Long): Long = { - val gap = clock.currentTime - originalStartTime - val newStartTime = (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime - start(newStartTime) - } - - def stop() { + def stop() { thread.interrupt() } - def loop() { + private def loop() { try { while (true) { clock.waitTillTime(nextTime) diff --git a/streaming/src/test/java/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/spark/streaming/JavaAPISuite.java index 79d60934297f3814ba41fca8d90f2ea76562d44c..3bed500f73e84548a2f470a414ea68cfbb3df1ac 100644 --- a/streaming/src/test/java/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/spark/streaming/JavaAPISuite.java @@ -11,7 +11,10 @@ import org.junit.Before; import org.junit.Test; import scala.Tuple2; import spark.HashPartitioner; +import spark.api.java.JavaPairRDD; import spark.api.java.JavaRDD; +import spark.api.java.JavaRDDLike; +import spark.api.java.JavaPairRDD; import spark.api.java.JavaSparkContext; import spark.api.java.function.*; import spark.storage.StorageLevel; @@ -21,10 +24,16 @@ import spark.streaming.api.java.JavaStreamingContext; import spark.streaming.JavaTestUtils; import spark.streaming.JavaCheckpointTestUtils; import spark.streaming.dstream.KafkaPartitionKey; +import spark.streaming.InputStreamsSuite; import java.io.*; import java.util.*; +import akka.actor.Props; +import akka.zeromq.Subscribe; + + + // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. @@ -33,8 +42,9 @@ public class JavaAPISuite implements Serializable { @Before public void setUp() { - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); - ssc.checkpoint("checkpoint", new Duration(1000)); + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); } @After @@ -133,29 +143,6 @@ public class JavaAPISuite implements Serializable { assertOrderInvariantEquals(expected, result); } - @Test - public void testTumble() { - List<List<Integer>> inputData = Arrays.asList( - Arrays.asList(1,2,3), - Arrays.asList(4,5,6), - Arrays.asList(7,8,9), - Arrays.asList(10,11,12), - Arrays.asList(13,14,15), - Arrays.asList(16,17,18)); - - List<List<Integer>> expected = Arrays.asList( - Arrays.asList(1,2,3,4,5,6), - Arrays.asList(7,8,9,10,11,12), - Arrays.asList(13,14,15,16,17,18)); - - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream windowed = stream.tumble(new Duration(2000)); - JavaTestUtils.attachTestOutputStream(windowed); - List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 6, 3); - - assertOrderInvariantEquals(expected, result); - } - @Test public void testFilter() { List<List<String>> inputData = Arrays.asList( @@ -315,8 +302,9 @@ public class JavaAPISuite implements Serializable { Arrays.asList(6,7,8), Arrays.asList(9,10,11)); - JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream transformed = stream.transform(new Function<JavaRDD<Integer>, JavaRDD<Integer>>() { + JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaDStream<Integer> transformed = + stream.transform(new Function<JavaRDD<Integer>, JavaRDD<Integer>>() { @Override public JavaRDD<Integer> call(JavaRDD<Integer> in) throws Exception { return in.map(new Function<Integer, Integer>() { @@ -506,6 +494,141 @@ public class JavaAPISuite implements Serializable { new Tuple2<String, Integer>("new york", 3), new Tuple2<String, Integer>("new york", 1))); + @Test + public void testPairMap() { // Maps pair -> pair of different type + List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream; + + List<List<Tuple2<Integer, String>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2<Integer, String>(1, "california"), + new Tuple2<Integer, String>(3, "california"), + new Tuple2<Integer, String>(4, "new york"), + new Tuple2<Integer, String>(1, "new york")), + Arrays.asList( + new Tuple2<Integer, String>(5, "california"), + new Tuple2<Integer, String>(5, "california"), + new Tuple2<Integer, String>(3, "new york"), + new Tuple2<Integer, String>(1, "new york"))); + + JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaPairDStream<Integer, String> reversed = pairStream.map( + new PairFunction<Tuple2<String, Integer>, Integer, String>() { + @Override + public Tuple2<Integer, String> call(Tuple2<String, Integer> in) throws Exception { + return in.swap(); + } + }); + + JavaTestUtils.attachTestOutputStream(reversed); + List<List<Tuple2<Integer, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairMapPartitions() { // Maps pair -> pair of different type + List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream; + + List<List<Tuple2<Integer, String>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2<Integer, String>(1, "california"), + new Tuple2<Integer, String>(3, "california"), + new Tuple2<Integer, String>(4, "new york"), + new Tuple2<Integer, String>(1, "new york")), + Arrays.asList( + new Tuple2<Integer, String>(5, "california"), + new Tuple2<Integer, String>(5, "california"), + new Tuple2<Integer, String>(3, "new york"), + new Tuple2<Integer, String>(1, "new york"))); + + JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaPairDStream<Integer, String> reversed = pairStream.mapPartitions( + new PairFlatMapFunction<Iterator<Tuple2<String, Integer>>, Integer, String>() { + @Override + public Iterable<Tuple2<Integer, String>> call(Iterator<Tuple2<String, Integer>> in) throws Exception { + LinkedList<Tuple2<Integer, String>> out = new LinkedList<Tuple2<Integer, String>>(); + while (in.hasNext()) { + Tuple2<String, Integer> next = in.next(); + out.add(next.swap()); + } + return out; + } + }); + + JavaTestUtils.attachTestOutputStream(reversed); + List<List<Tuple2<Integer, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairMap2() { // Maps pair -> single + List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream; + + List<List<Integer>> expected = Arrays.asList( + Arrays.asList(1, 3, 4, 1), + Arrays.asList(5, 5, 3, 1)); + + JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaDStream<Integer> reversed = pairStream.map( + new Function<Tuple2<String, Integer>, Integer>() { + @Override + public Integer call(Tuple2<String, Integer> in) throws Exception { + return in._2(); + } + }); + + JavaTestUtils.attachTestOutputStream(reversed); + List<List<Tuple2<Integer, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairToPairFlatMapWithChangingTypes() { // Maps pair -> pair + List<List<Tuple2<String, Integer>>> inputData = Arrays.asList( + Arrays.asList( + new Tuple2<String, Integer>("hi", 1), + new Tuple2<String, Integer>("ho", 2)), + Arrays.asList( + new Tuple2<String, Integer>("hi", 1), + new Tuple2<String, Integer>("ho", 2))); + + List<List<Tuple2<Integer, String>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2<Integer, String>(1, "h"), + new Tuple2<Integer, String>(1, "i"), + new Tuple2<Integer, String>(2, "h"), + new Tuple2<Integer, String>(2, "o")), + Arrays.asList( + new Tuple2<Integer, String>(1, "h"), + new Tuple2<Integer, String>(1, "i"), + new Tuple2<Integer, String>(2, "h"), + new Tuple2<Integer, String>(2, "o"))); + + JavaDStream<Tuple2<String, Integer>> stream = + JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaPairDStream<Integer, String> flatMapped = pairStream.flatMap( + new PairFlatMapFunction<Tuple2<String, Integer>, Integer, String>() { + @Override + public Iterable<Tuple2<Integer, String>> call(Tuple2<String, Integer> in) throws Exception { + List<Tuple2<Integer, String>> out = new LinkedList<Tuple2<Integer, String>>(); + for (Character s : in._1().toCharArray()) { + out.add(new Tuple2<Integer, String>(in._2(), s.toString())); + } + return out; + } + }); + JavaTestUtils.attachTestOutputStream(flatMapped); + List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + @Test public void testPairGroupByKey() { List<List<Tuple2<String, String>>> inputData = stringStringKVStream; @@ -570,7 +693,7 @@ public class JavaAPISuite implements Serializable { JavaPairDStream<String, Integer> combined = pairStream.<Integer>combineByKey( new Function<Integer, Integer>() { - @Override + @Override public Integer call(Integer i) throws Exception { return i; } @@ -583,50 +706,73 @@ public class JavaAPISuite implements Serializable { } @Test - public void testCountByKey() { - List<List<Tuple2<String, String>>> inputData = stringStringKVStream; + public void testCountByValue() { + List<List<String>> inputData = Arrays.asList( + Arrays.asList("hello", "world"), + Arrays.asList("hello", "moon"), + Arrays.asList("hello")); List<List<Tuple2<String, Long>>> expected = Arrays.asList( - Arrays.asList( - new Tuple2<String, Long>("california", 2L), - new Tuple2<String, Long>("new york", 2L)), - Arrays.asList( - new Tuple2<String, Long>("california", 2L), - new Tuple2<String, Long>("new york", 2L))); - - JavaDStream<Tuple2<String, String>> stream = JavaTestUtils.attachTestInputStream( - ssc, inputData, 1); - JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream); + Arrays.asList( + new Tuple2<String, Long>("hello", 1L), + new Tuple2<String, Long>("world", 1L)), + Arrays.asList( + new Tuple2<String, Long>("hello", 1L), + new Tuple2<String, Long>("moon", 1L)), + Arrays.asList( + new Tuple2<String, Long>("hello", 1L))); - JavaPairDStream<String, Long> counted = pairStream.countByKey(); + JavaDStream<String> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream<String, Long> counted = stream.countByValue(); JavaTestUtils.attachTestOutputStream(counted); - List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 3, 3); Assert.assertEquals(expected, result); } @Test public void testGroupByKeyAndWindow() { - List<List<Tuple2<String, String>>> inputData = stringStringKVStream; + List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream; - List<List<Tuple2<String, List<String>>>> expected = Arrays.asList( - Arrays.asList(new Tuple2<String, List<String>>("california", Arrays.asList("dodgers", "giants")), - new Tuple2<String, List<String>>("new york", Arrays.asList("yankees", "mets"))), - Arrays.asList(new Tuple2<String, List<String>>("california", - Arrays.asList("sharks", "ducks", "dodgers", "giants")), - new Tuple2<String, List<String>>("new york", Arrays.asList("rangers", "islanders", "yankees", "mets"))), - Arrays.asList(new Tuple2<String, List<String>>("california", Arrays.asList("sharks", "ducks")), - new Tuple2<String, List<String>>("new york", Arrays.asList("rangers", "islanders")))); + List<List<Tuple2<String, List<Integer>>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2<String, List<Integer>>("california", Arrays.asList(1, 3)), + new Tuple2<String, List<Integer>>("new york", Arrays.asList(1, 4)) + ), + Arrays.asList( + new Tuple2<String, List<Integer>>("california", Arrays.asList(1, 3, 5, 5)), + new Tuple2<String, List<Integer>>("new york", Arrays.asList(1, 1, 3, 4)) + ), + Arrays.asList( + new Tuple2<String, List<Integer>>("california", Arrays.asList(5, 5)), + new Tuple2<String, List<Integer>>("new york", Arrays.asList(1, 3)) + ) + ); - JavaDStream<Tuple2<String, String>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream); + JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); + JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream<String, List<String>> groupWindowed = + JavaPairDStream<String, List<Integer>> groupWindowed = pairStream.groupByKeyAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(groupWindowed); - List<List<Tuple2<String, List<String>>>> result = JavaTestUtils.runStreams(ssc, 3, 3); + List<List<Tuple2<String, List<Integer>>>> result = JavaTestUtils.runStreams(ssc, 3, 3); - Assert.assertEquals(expected, result); + assert(result.size() == expected.size()); + for (int i = 0; i < result.size(); i++) { + assert(convert(result.get(i)).equals(convert(expected.get(i)))); + } + } + + private HashSet<Tuple2<String, HashSet<Integer>>> convert(List<Tuple2<String, List<Integer>>> listOfTuples) { + List<Tuple2<String, HashSet<Integer>>> newListOfTuples = new ArrayList<Tuple2<String, HashSet<Integer>>>(); + for (Tuple2<String, List<Integer>> tuple: listOfTuples) { + newListOfTuples.add(convert(tuple)); + } + return new HashSet<Tuple2<String, HashSet<Integer>>>(newListOfTuples); + } + + private Tuple2<String, HashSet<Integer>> convert(Tuple2<String, List<Integer>> tuple) { + return new Tuple2<String, HashSet<Integer>>(tuple._1(), new HashSet<Integer>(tuple._2())); } @Test @@ -668,7 +814,7 @@ public class JavaAPISuite implements Serializable { JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream<String, Integer> updated = pairStream.updateStateByKey( - new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>(){ + new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() { @Override public Optional<Integer> call(List<Integer> values, Optional<Integer> state) { int out = 0; @@ -680,7 +826,7 @@ public class JavaAPISuite implements Serializable { } return Optional.of(out); } - }); + }); JavaTestUtils.attachTestOutputStream(updated); List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -711,26 +857,28 @@ public class JavaAPISuite implements Serializable { } @Test - public void testCountByKeyAndWindow() { - List<List<Tuple2<String, String>>> inputData = stringStringKVStream; + public void testCountByValueAndWindow() { + List<List<String>> inputData = Arrays.asList( + Arrays.asList("hello", "world"), + Arrays.asList("hello", "moon"), + Arrays.asList("hello")); List<List<Tuple2<String, Long>>> expected = Arrays.asList( Arrays.asList( - new Tuple2<String, Long>("california", 2L), - new Tuple2<String, Long>("new york", 2L)), + new Tuple2<String, Long>("hello", 1L), + new Tuple2<String, Long>("world", 1L)), Arrays.asList( - new Tuple2<String, Long>("california", 4L), - new Tuple2<String, Long>("new york", 4L)), + new Tuple2<String, Long>("hello", 2L), + new Tuple2<String, Long>("world", 1L), + new Tuple2<String, Long>("moon", 1L)), Arrays.asList( - new Tuple2<String, Long>("california", 2L), - new Tuple2<String, Long>("new york", 2L))); + new Tuple2<String, Long>("hello", 2L), + new Tuple2<String, Long>("moon", 1L))); - JavaDStream<Tuple2<String, String>> stream = JavaTestUtils.attachTestInputStream( + JavaDStream<String> stream = JavaTestUtils.attachTestInputStream( ssc, inputData, 1); - JavaPairDStream<String, String> pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream<String, Long> counted = - pairStream.countByKeyAndWindow(new Duration(2000), new Duration(1000)); + stream.countByValueAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(counted); List<List<Tuple2<String, Long>>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -738,6 +886,90 @@ public class JavaAPISuite implements Serializable { } @Test + public void testPairTransform() { + List<List<Tuple2<Integer, Integer>>> inputData = Arrays.asList( + Arrays.asList( + new Tuple2<Integer, Integer>(3, 5), + new Tuple2<Integer, Integer>(1, 5), + new Tuple2<Integer, Integer>(4, 5), + new Tuple2<Integer, Integer>(2, 5)), + Arrays.asList( + new Tuple2<Integer, Integer>(2, 5), + new Tuple2<Integer, Integer>(3, 5), + new Tuple2<Integer, Integer>(4, 5), + new Tuple2<Integer, Integer>(1, 5))); + + List<List<Tuple2<Integer, Integer>>> expected = Arrays.asList( + Arrays.asList( + new Tuple2<Integer, Integer>(1, 5), + new Tuple2<Integer, Integer>(2, 5), + new Tuple2<Integer, Integer>(3, 5), + new Tuple2<Integer, Integer>(4, 5)), + Arrays.asList( + new Tuple2<Integer, Integer>(1, 5), + new Tuple2<Integer, Integer>(2, 5), + new Tuple2<Integer, Integer>(3, 5), + new Tuple2<Integer, Integer>(4, 5))); + + JavaDStream<Tuple2<Integer, Integer>> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream<Integer, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaPairDStream<Integer, Integer> sorted = pairStream.transform( + new Function<JavaPairRDD<Integer, Integer>, JavaPairRDD<Integer, Integer>>() { + @Override + public JavaPairRDD<Integer, Integer> call(JavaPairRDD<Integer, Integer> in) throws Exception { + return in.sortByKey(); + } + }); + + JavaTestUtils.attachTestOutputStream(sorted); + List<List<Tuple2<String, String>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + + @Test + public void testPairToNormalRDDTransform() { + List<List<Tuple2<Integer, Integer>>> inputData = Arrays.asList( + Arrays.asList( + new Tuple2<Integer, Integer>(3, 5), + new Tuple2<Integer, Integer>(1, 5), + new Tuple2<Integer, Integer>(4, 5), + new Tuple2<Integer, Integer>(2, 5)), + Arrays.asList( + new Tuple2<Integer, Integer>(2, 5), + new Tuple2<Integer, Integer>(3, 5), + new Tuple2<Integer, Integer>(4, 5), + new Tuple2<Integer, Integer>(1, 5))); + + List<List<Integer>> expected = Arrays.asList( + Arrays.asList(3,1,4,2), + Arrays.asList(2,3,4,1)); + + JavaDStream<Tuple2<Integer, Integer>> stream = JavaTestUtils.attachTestInputStream( + ssc, inputData, 1); + JavaPairDStream<Integer, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream); + + JavaDStream<Integer> firstParts = pairStream.transform( + new Function<JavaPairRDD<Integer, Integer>, JavaRDD<Integer>>() { + @Override + public JavaRDD<Integer> call(JavaPairRDD<Integer, Integer> in) throws Exception { + return in.map(new Function<Tuple2<Integer, Integer>, Integer>() { + @Override + public Integer call(Tuple2<Integer, Integer> in) { + return in._1(); + } + }); + } + }); + + JavaTestUtils.attachTestOutputStream(firstParts); + List<List<Integer>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected, result); + } + public void testMapValues() { List<List<Tuple2<String, String>>> inputData = stringStringKVStream; @@ -911,9 +1143,8 @@ public class JavaAPISuite implements Serializable { Arrays.asList(1,4), Arrays.asList(8,7)); - File tempDir = Files.createTempDir(); - ssc.checkpoint(tempDir.getAbsolutePath(), new Duration(1000)); + ssc.checkpoint(tempDir.getAbsolutePath()); JavaDStream stream = JavaCheckpointTestUtils.attachTestInputStream(ssc, inputData, 1); JavaDStream letterCount = stream.map(new Function<String, Integer>() { @@ -927,14 +1158,16 @@ public class JavaAPISuite implements Serializable { assertOrderInvariantEquals(expectedInitial, initialResult); Thread.sleep(1000); - ssc.stop(); + ssc = new JavaStreamingContext(tempDir.getAbsolutePath()); - ssc.start(); - List<List<Integer>> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 2); - assertOrderInvariantEquals(expectedFinal, finalResult); + // Tweak to take into consideration that the last batch before failure + // will be re-processed after recovery + List<List<Integer>> finalResult = JavaCheckpointTestUtils.runStreams(ssc, 2, 3); + assertOrderInvariantEquals(expectedFinal, finalResult.subList(1, 3)); } + /** TEST DISABLED: Pending a discussion about checkpoint() semantics with TD @Test public void testCheckpointofIndividualStream() throws InterruptedException { @@ -971,19 +1204,19 @@ public class JavaAPISuite implements Serializable { public void testKafkaStream() { HashMap<String, Integer> topics = Maps.newHashMap(); HashMap<KafkaPartitionKey, Long> offsets = Maps.newHashMap(); - JavaDStream test1 = ssc.kafkaStream("localhost", 12345, "group", topics); - JavaDStream test2 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets); - JavaDStream test3 = ssc.kafkaStream("localhost", 12345, "group", topics, offsets, + JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics); + JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, offsets); + JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, offsets, StorageLevel.MEMORY_AND_DISK()); } @Test - public void testNetworkTextStream() { - JavaDStream test = ssc.networkTextStream("localhost", 12345); + public void testSocketTextStream() { + JavaDStream test = ssc.socketTextStream("localhost", 12345); } @Test - public void testNetworkString() { + public void testSocketString() { class Converter extends Function<InputStream, Iterable<String>> { public Iterable<String> call(InputStream in) { BufferedReader reader = new BufferedReader(new InputStreamReader(in)); @@ -999,7 +1232,7 @@ public class JavaAPISuite implements Serializable { } } - JavaDStream test = ssc.networkStream( + JavaDStream test = ssc.socketStream( "localhost", 12345, new Converter(), @@ -1012,13 +1245,13 @@ public class JavaAPISuite implements Serializable { } @Test - public void testRawNetworkStream() { - JavaDStream test = ssc.rawNetworkStream("localhost", 12345); + public void testRawSocketStream() { + JavaDStream test = ssc.rawSocketStream("localhost", 12345); } @Test public void testFlumeStream() { - JavaDStream test = ssc.flumeStream("localhost", 12345); + JavaDStream test = ssc.flumeStream("localhost", 12345, StorageLevel.MEMORY_ONLY()); } @Test @@ -1026,4 +1259,25 @@ public class JavaAPISuite implements Serializable { JavaPairDStream<String, String> foo = ssc.<String, String, SequenceFileInputFormat>fileStream("/tmp/foo"); } + + @Test + public void testTwitterStream() { + String[] filters = new String[] { "good", "bad", "ugly" }; + JavaDStream test = ssc.twitterStream("username", "password", filters, StorageLevel.MEMORY_ONLY()); + } + + @Test + public void testActorStream() { + JavaDStream test = ssc.actorStream((Props)null, "TestActor", StorageLevel.MEMORY_ONLY()); + } + + @Test + public void testZeroMQStream() { + JavaDStream test = ssc.zeroMQStream("url", (Subscribe) null, new Function<byte[][], Iterable<String>>() { + @Override + public Iterable<String> call(byte[][] b) throws Exception { + return null; + } + }); + } } diff --git a/streaming/src/test/java/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala index 56349837e5e6ba5895c2de50103a24e9d9e94bfe..64a7e7cbf9a367285b9a3704f781235dac099bbb 100644 --- a/streaming/src/test/java/spark/streaming/JavaTestUtils.scala +++ b/streaming/src/test/java/spark/streaming/JavaTestUtils.scala @@ -31,8 +31,9 @@ trait JavaTestBase extends TestSuiteBase { * Attach a provided stream to it's associated StreamingContext as a * [[spark.streaming.TestOutputStream]]. **/ - def attachTestOutputStream[T, This <: spark.streaming.api.java.JavaDStreamLike[T,This]]( - dstream: JavaDStreamLike[T, This]) = { + def attachTestOutputStream[T, This <: spark.streaming.api.java.JavaDStreamLike[T, This, R], + R <: spark.api.java.JavaRDDLike[T, R]]( + dstream: JavaDStreamLike[T, This, R]) = { implicit val cm: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] val ostream = new TestOutputStream(dstream.dstream, @@ -57,6 +58,7 @@ trait JavaTestBase extends TestSuiteBase { } object JavaTestUtils extends JavaTestBase { + override def maxWaitTimeMillis = 20000 } diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index edfa1243fa2bc54ca72d5af3b4d0ad7df9263a56..59c445e63f7974163e77d47257644474dd14e62f 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -1,5 +1,6 @@ # Set everything to be logged to the file streaming/target/unit-tests.log log4j.rootCategory=INFO, file +# log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false log4j.appender.file.file=streaming/target/unit-tests.log diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index 4a036f071074f8ed4228e61395e12e585be17d07..8fce91853c77eed0f5d3e5e48aa2e93f13988e3e 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -6,6 +6,8 @@ import util.ManualClock class BasicOperationsSuite extends TestSuiteBase { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + override def framework() = "BasicOperationsSuite" after { @@ -22,7 +24,7 @@ class BasicOperationsSuite extends TestSuiteBase { ) } - test("flatmap") { + test("flatMap") { val input = Seq(1 to 4, 5 to 8, 9 to 12) testOperation( input, @@ -86,6 +88,23 @@ class BasicOperationsSuite extends TestSuiteBase { ) } + test("count") { + testOperation( + Seq(1 to 1, 1 to 2, 1 to 3, 1 to 4), + (s: DStream[Int]) => s.count(), + Seq(Seq(1L), Seq(2L), Seq(3L), Seq(4L)) + ) + } + + test("countByValue") { + testOperation( + Seq(1 to 1, Seq(1, 1, 1), 1 to 2, Seq(1, 1, 2, 2)), + (s: DStream[Int]) => s.countByValue(), + Seq(Seq((1, 1L)), Seq((1, 3L)), Seq((1, 1L), (2, 1L)), Seq((2, 2L), (1, 2L))), + true + ) + } + test("mapValues") { testOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), @@ -165,6 +184,71 @@ class BasicOperationsSuite extends TestSuiteBase { testOperation(inputData, updateStateOperation, outputData, true) } + test("updateStateByKey - object lifecycle") { + val inputData = + Seq( + Seq("a","b"), + null, + Seq("a","c","a"), + Seq("c"), + null, + null + ) + + val outputData = + Seq( + Seq(("a", 1), ("b", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 3), ("c", 1)), + Seq(("a", 3), ("c", 2)), + Seq(("c", 2)), + Seq() + ) + + val updateStateOperation = (s: DStream[String]) => { + class StateObject(var counter: Int = 0, var expireCounter: Int = 0) extends Serializable + + // updateFunc clears a state when a StateObject is seen without new values twice in a row + val updateFunc = (values: Seq[Int], state: Option[StateObject]) => { + val stateObj = state.getOrElse(new StateObject) + values.foldLeft(0)(_ + _) match { + case 0 => stateObj.expireCounter += 1 // no new values + case n => { // has new values, increment and reset expireCounter + stateObj.counter += n + stateObj.expireCounter = 0 + } + } + stateObj.expireCounter match { + case 2 => None // seen twice with no new values, give it the boot + case _ => Option(stateObj) + } + } + s.map(x => (x, 1)).updateStateByKey[StateObject](updateFunc).mapValues(_.counter) + } + + testOperation(inputData, updateStateOperation, outputData, true) + } + + test("slice") { + val ssc = new StreamingContext("local[2]", "BasicOperationSuite", Seconds(1)) + val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) + val stream = new TestInputStream[Int](ssc, input, 2) + ssc.registerInputStream(stream) + stream.foreach(_ => {}) // Dummy output stream + ssc.start() + Thread.sleep(2000) + def getInputFromSlice(fromMillis: Long, toMillis: Long) = { + stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet + } + + assert(getInputFromSlice(0, 1000) == Set(1)) + assert(getInputFromSlice(0, 2000) == Set(1, 2)) + assert(getInputFromSlice(1000, 2000) == Set(1, 2)) + assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4)) + ssc.stop() + Thread.sleep(1000) + } + test("forgetting of RDDs - map and window operations") { assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second") diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 563a7d14587b591aadb5d7587348607b28d0bf46..cac86deeaf3492cb298c26d9812cc82e71fd83cf 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -1,5 +1,6 @@ package spark.streaming +import dstream.FileInputDStream import spark.streaming.StreamingContext._ import java.io.File import runtime.RichInt @@ -7,9 +8,19 @@ import org.scalatest.BeforeAndAfter import org.apache.commons.io.FileUtils import collection.mutable.{SynchronizedBuffer, ArrayBuffer} import util.{Clock, ManualClock} +import scala.util.Random +import com.google.common.io.Files + +/** + * This test suites tests the checkpointing functionality of DStreams - + * the checkpointing of a DStream's RDDs as well as the checkpointing of + * the whole DStream graph. + */ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + before { FileUtils.deleteDirectory(new File(checkpointDir)) } @@ -28,21 +39,18 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { override def batchDuration = Milliseconds(500) - override def checkpointInterval = batchDuration - override def actuallyWait = true - test("basic stream+rdd recovery") { + test("basic rdd checkpoints + dstream graph checkpoint recovery") { assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") - assert(checkpointInterval === batchDuration, "checkpointInterval for this test much be same as batchDuration") System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") val stateStreamCheckpointInterval = Seconds(1) // this ensure checkpointing occurs at least once - val firstNumBatches = (stateStreamCheckpointInterval / batchDuration) * 2 + val firstNumBatches = (stateStreamCheckpointInterval / batchDuration).toLong * 2 val secondNumBatches = firstNumBatches // Setup the streams @@ -62,10 +70,10 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Run till a time such that at least one RDD in the stream should have been checkpointed, // then check whether some RDD has been checkpointed or not ssc.start() - runStreamsWithRealDelay(ssc, firstNumBatches) - logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.rdds.mkString(",\n") + "]") - assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before first failure") - stateStream.checkpointData.rdds.foreach { + advanceTimeWithRealDelay(ssc, firstNumBatches) + logInfo("Checkpoint data of state stream = \n" + stateStream.checkpointData) + assert(!stateStream.checkpointData.checkpointFiles.isEmpty, "No checkpointed RDDs in state stream before first failure") + stateStream.checkpointData.checkpointFiles.foreach { case (time, data) => { val file = new File(data.toString) assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") @@ -74,8 +82,8 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Run till a further time such that previous checkpoint files in the stream would be deleted // and check whether the earlier checkpoint files are deleted - val checkpointFiles = stateStream.checkpointData.rdds.map(x => new File(x._2.toString)) - runStreamsWithRealDelay(ssc, secondNumBatches) + val checkpointFiles = stateStream.checkpointData.checkpointFiles.map(x => new File(x._2)) + advanceTimeWithRealDelay(ssc, secondNumBatches) checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) ssc.stop() @@ -90,9 +98,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Run one batch to generate a new checkpoint file and check whether some RDD // is present in the checkpoint data or not ssc.start() - runStreamsWithRealDelay(ssc, 1) - assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before second failure") - stateStream.checkpointData.rdds.foreach { + advanceTimeWithRealDelay(ssc, 1) + assert(!stateStream.checkpointData.checkpointFiles.isEmpty, "No checkpointed RDDs in state stream before second failure") + stateStream.checkpointData.checkpointFiles.foreach { case (time, data) => { val file = new File(data.toString) assert(file.exists(), @@ -111,13 +119,16 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Adjust manual clock time as if it is being restarted after a delay System.setProperty("spark.streaming.manualClock.jump", (batchDuration.milliseconds * 7).toString) ssc.start() - runStreamsWithRealDelay(ssc, 4) + advanceTimeWithRealDelay(ssc, 4) ssc.stop() System.clearProperty("spark.streaming.manualClock.jump") ssc = null } - test("map and reduceByKey") { + // This tests whether the systm can recover from a master failure with simple + // non-stateful operations. This assumes as reliable, replayable input + // source - TestInputDStream. + test("recovery with map and reduceByKey operations") { testCheckpointedOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), @@ -126,7 +137,11 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { ) } - test("reduceByKeyAndWindowInv") { + + // This tests whether the ReduceWindowedDStream's RDD checkpoints works correctly such + // that the system can recover from a master failure. This assumes as reliable, + // replayable input source - TestInputDStream. + test("recovery with invertible reduceByKeyAndWindow operation") { val n = 10 val w = 4 val input = (1 to n).map(_ => Seq("a")).toSeq @@ -139,7 +154,11 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { testCheckpointedOperation(input, operation, output, 7) } - test("updateStateByKey") { + + // This tests whether the StateDStream's RDD checkpoints works correctly such + // that the system can recover from a master failure. This assumes as reliable, + // replayable input source - TestInputDStream. + test("recovery with updateStateByKey operation") { val input = (1 to 10).map(_ => Seq("a")).toSeq val output = (1 to 10).map(x => Seq(("a", x))).toSeq val operation = (st: DStream[String]) => { @@ -154,11 +173,126 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { testCheckpointedOperation(input, operation, output, 7) } + // This tests whether file input stream remembers what files were seen before + // the master failure and uses them again to process a large window operation. + // It also tests whether batches, whose processing was incomplete due to the + // failure, are re-processed or not. + test("recovery with file input stream") { + // Disable manual clock as FileInputDStream does not work with manual clock + val clockProperty = System.getProperty("spark.streaming.clock") + System.clearProperty("spark.streaming.clock") + + // Set up the streaming context and input streams + val testDir = Files.createTempDir() + var ssc = new StreamingContext(master, framework, Seconds(1)) + ssc.checkpoint(checkpointDir) + val fileStream = ssc.textFileStream(testDir.toString) + // Making value 3 take large time to process, to ensure that the master + // shuts down in the middle of processing the 3rd batch + val mappedStream = fileStream.map(s => { + val i = s.toInt + if (i == 3) Thread.sleep(2000) + i + }) + + // Reducing over a large window to ensure that recovery from master failure + // requires reprocessing of all the files seen before the failure + val reducedStream = mappedStream.reduceByWindow(_ + _, Seconds(30), Seconds(1)) + val outputBuffer = new ArrayBuffer[Seq[Int]] + var outputStream = new TestOutputStream(reducedStream, outputBuffer) + ssc.registerOutputStream(outputStream) + ssc.start() + + // Create files and advance manual clock to process them + //var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + Thread.sleep(1000) + for (i <- Seq(1, 2, 3)) { + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") + // wait to make sure that the file is written such that it gets shown in the file listings + Thread.sleep(1000) + } + logInfo("Output = " + outputStream.output.mkString(",")) + assert(outputStream.output.size > 0, "No files processed before restart") + ssc.stop() + + // Verify whether files created have been recorded correctly or not + var fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] + def recordedFiles = fileInputDStream.files.values.flatMap(x => x) + assert(!recordedFiles.filter(_.endsWith("1")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("2")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("3")).isEmpty) + + // Create files while the master is down + for (i <- Seq(4, 5, 6)) { + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") + Thread.sleep(1000) + } + + // Recover context from checkpoint file and verify whether the files that were + // recorded before failure were saved and successfully recovered + logInfo("*********** RESTARTING ************") + ssc = new StreamingContext(checkpointDir) + fileInputDStream = ssc.graph.getInputStreams().head.asInstanceOf[FileInputDStream[_, _, _]] + assert(!recordedFiles.filter(_.endsWith("1")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("2")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("3")).isEmpty) + + // Restart stream computation + ssc.start() + for (i <- Seq(7, 8, 9)) { + FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") + Thread.sleep(1000) + } + Thread.sleep(1000) + logInfo("Output = " + outputStream.output.mkString("[", ", ", "]")) + assert(outputStream.output.size > 0, "No files processed after restart") + ssc.stop() + + // Verify whether files created while the driver was down have been recorded or not + assert(!recordedFiles.filter(_.endsWith("4")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("5")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("6")).isEmpty) + + // Verify whether new files created after recover have been recorded or not + assert(!recordedFiles.filter(_.endsWith("7")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("8")).isEmpty) + assert(!recordedFiles.filter(_.endsWith("9")).isEmpty) + + // Append the new output to the old buffer + outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]] + outputBuffer ++= outputStream.output + + val expectedOutput = Seq(1, 3, 6, 10, 15, 21, 28, 36, 45) + logInfo("--------------------------------") + logInfo("output, size = " + outputBuffer.size) + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output, size = " + expectedOutput.size) + expectedOutput.foreach(x => logInfo("[" + x + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + val output = outputBuffer.flatMap(x => x) + assert(output.contains(6)) // To ensure that the 3rd input (i.e., 3) was processed + output.foreach(o => // To ensure all the inputs are correctly added cumulatively + assert(expectedOutput.contains(o), "Expected value " + o + " not found") + ) + // To ensure that all the inputs were received correctly + assert(expectedOutput.last === output.last) + + // Enable manual clock back again for other tests + if (clockProperty != null) + System.setProperty("spark.streaming.clock", clockProperty) + } + + /** - * Tests a streaming operation under checkpointing, by restart the operation + * Tests a streaming operation under checkpointing, by restarting the operation * from checkpoint file and verifying whether the final output is correct. * The output is assumed to have come from a reliable queue which an replay * data as required. + * + * NOTE: This takes into consideration that the last batch processed before + * master failure will be re-processed after restart/recovery. */ def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], @@ -172,11 +306,14 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { val totalNumBatches = input.size val nextNumBatches = totalNumBatches - initialNumBatches val initialNumExpectedOutputs = initialNumBatches - val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1 + // because the last batch will be processed again // Do the computation for initial number of batches, create checkpoint file and quit ssc = setupStreams[U, V](input, operation) - val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) + ssc.start() + val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches) + ssc.stop() verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) Thread.sleep(1000) @@ -187,16 +324,20 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { "\n-------------------------------------------\n" ) ssc = new StreamingContext(checkpointDir) - val outputNew = runStreams[V](ssc, nextNumBatches, nextNumExpectedOutputs) + System.clearProperty("spark.driver.port") + ssc.start() + val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) + // the first element will be re-processed data of the last batch before restart verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) + ssc.stop() ssc = null } /** * Advances the manual clock on the streaming scheduler by given number of batches. - * It also wait for the expected amount of time for each batch. + * It also waits for the expected amount of time for each batch. */ - def runStreamsWithRealDelay(ssc: StreamingContext, numBatches: Long) { + def advanceTimeWithRealDelay[V: ClassManifest](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.time) for (i <- 1 to numBatches.toInt) { @@ -205,6 +346,8 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { } logInfo("Manual clock after advancing = " + clock.time) Thread.sleep(batchDuration.milliseconds) - } + val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] + outputStream.output + } } \ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/spark/streaming/FailureSuite.scala index c4cfffbfc1ef3cf8f8c79a4963525a0ee936f28f..a5fa7ab92dd6bf9564e1f33aba39d82a51deda04 100644 --- a/streaming/src/test/scala/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/spark/streaming/FailureSuite.scala @@ -1,191 +1,40 @@ package spark.streaming -import org.scalatest.BeforeAndAfter -import org.apache.commons.io.FileUtils +import spark.Logging +import spark.streaming.util.MasterFailureTest +import StreamingContext._ + +import org.scalatest.{FunSuite, BeforeAndAfter} +import com.google.common.io.Files import java.io.File -import scala.runtime.RichInt -import scala.util.Random -import spark.streaming.StreamingContext._ +import org.apache.commons.io.FileUtils import collection.mutable.ArrayBuffer -import spark.Logging + /** * This testsuite tests master failures at random times while the stream is running using * the real clock. */ -class FailureSuite extends TestSuiteBase with BeforeAndAfter { +class FailureSuite extends FunSuite with BeforeAndAfter with Logging { + + var directory = "FailureSuite" + val numBatches = 30 + val batchDuration = Milliseconds(1000) before { - FileUtils.deleteDirectory(new File(checkpointDir)) + FileUtils.deleteDirectory(new File(directory)) } after { - FailureSuite.reset() - FileUtils.deleteDirectory(new File(checkpointDir)) - - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown - System.clearProperty("spark.driver.port") - } - - override def framework = "CheckpointSuite" - - override def batchDuration = Milliseconds(500) - - override def checkpointDir = "checkpoint" - - override def checkpointInterval = batchDuration - - test("multiple failures with updateStateByKey") { - val n = 30 - // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... - val input = (1 to n).map(i => (1 to i).map(_ =>"a").toSeq).toSeq - // Last output: [ (a, 465) ] for n=30 - val lastOutput = Seq( ("a", (1 to n).reduce(_ + _)) ) - - val operation = (st: DStream[String]) => { - val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { - Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) - } - st.map(x => (x, 1)) - .updateStateByKey[RichInt](updateFunc) - .checkpoint(Seconds(2)) - .map(t => (t._1, t._2.self)) - } - - testOperationWithMultipleFailures(input, operation, lastOutput, n, n) - } - - test("multiple failures with reduceByKeyAndWindow") { - val n = 30 - val w = 100 - assert(w > n, "Window should be much larger than the number of input sets in this test") - // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... - val input = (1 to n).map(i => (1 to i).map(_ =>"a").toSeq).toSeq - // Last output: [ (a, 465) ] - val lastOutput = Seq( ("a", (1 to n).reduce(_ + _)) ) - - val operation = (st: DStream[String]) => { - st.map(x => (x, 1)) - .reduceByKeyAndWindow(_ + _, _ - _, batchDuration * w, batchDuration) - .checkpoint(Seconds(2)) - } - - testOperationWithMultipleFailures(input, operation, lastOutput, n, n) - } - - - /** - * Tests stream operation with multiple master failures, and verifies whether the - * final set of output values is as expected or not. Checking the final value is - * proof that no intermediate data was lost due to master failures. - */ - def testOperationWithMultipleFailures[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - lastExpectedOutput: Seq[V], - numBatches: Int, - numExpectedOutput: Int - ) { - var ssc = setupStreams[U, V](input, operation) - val mergedOutput = new ArrayBuffer[Seq[V]]() - - var totalTimeRan = 0L - while(totalTimeRan <= numBatches * batchDuration.milliseconds * 2) { - new KillingThread(ssc, numBatches * batchDuration.milliseconds.toInt / 4).start() - val (output, timeRan) = runStreamsWithRealClock[V](ssc, numBatches, numExpectedOutput) - - mergedOutput ++= output - totalTimeRan += timeRan - logInfo("New output = " + output) - logInfo("Merged output = " + mergedOutput) - logInfo("Total time spent = " + totalTimeRan) - val sleepTime = Random.nextInt(numBatches * batchDuration.milliseconds.toInt / 8) - logInfo( - "\n-------------------------------------------\n" + - " Restarting stream computation in " + sleepTime + " ms " + - "\n-------------------------------------------\n" - ) - Thread.sleep(sleepTime) - FailureSuite.failed = false - ssc = new StreamingContext(checkpointDir) - } - ssc.stop() - ssc = null - - // Verify whether the last output is the expected one - val lastOutput = mergedOutput(mergedOutput.lastIndexWhere(!_.isEmpty)) - assert(lastOutput.toSet === lastExpectedOutput.toSet) - logInfo("Finished computation after " + FailureSuite.failureCount + " failures") + FileUtils.deleteDirectory(new File(directory)) } - /** - * Runs the streams set up in `ssc` on real clock until the expected max number of - */ - def runStreamsWithRealClock[V: ClassManifest]( - ssc: StreamingContext, - numBatches: Int, - maxExpectedOutput: Int - ): (Seq[Seq[V]], Long) = { - - System.clearProperty("spark.streaming.clock") - - assert(numBatches > 0, "Number of batches to run stream computation is zero") - assert(maxExpectedOutput > 0, "Max expected outputs after " + numBatches + " is zero") - logInfo("numBatches = " + numBatches + ", maxExpectedOutput = " + maxExpectedOutput) - - // Get the output buffer - val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] - val output = outputStream.output - val waitTime = (batchDuration.milliseconds * (numBatches.toDouble + 0.5)).toLong - val startTime = System.currentTimeMillis() - - try { - // Start computation - ssc.start() - - // Wait until expected number of output items have been generated - while (output.size < maxExpectedOutput && System.currentTimeMillis() - startTime < waitTime && !FailureSuite.failed) { - logInfo("output.size = " + output.size + ", maxExpectedOutput = " + maxExpectedOutput) - Thread.sleep(100) - } - } catch { - case e: Exception => logInfo("Exception while running streams: " + e) - } finally { - ssc.stop() - } - val timeTaken = System.currentTimeMillis() - startTime - logInfo("" + output.size + " sets of output generated in " + timeTaken + " ms") - (output, timeTaken) + test("multiple failures with map") { + MasterFailureTest.testMap(directory, numBatches, batchDuration) } - -} - -object FailureSuite { - var failed = false - var failureCount = 0 - - def reset() { - failed = false - failureCount = 0 + test("multiple failures with updateStateByKey") { + MasterFailureTest.testUpdateStateByKey(directory, numBatches, batchDuration) } } -class KillingThread(ssc: StreamingContext, maxKillWaitTime: Int) extends Thread with Logging { - initLogging() - - override def run() { - var minKillWaitTime = if (FailureSuite.failureCount == 0) 3000 else 1000 // to allow the first checkpoint - val killWaitTime = minKillWaitTime + Random.nextInt(maxKillWaitTime) - logInfo("Kill wait time = " + killWaitTime) - Thread.sleep(killWaitTime.toLong) - logInfo( - "\n---------------------------------------\n" + - "Killing streaming context after " + killWaitTime + " ms" + - "\n---------------------------------------\n" - ) - if (ssc != null) ssc.stop() - FailureSuite.failed = true - FailureSuite.failureCount += 1 - } -} diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index 70ae6e3934cfd3403e47903a3adaa46d024fddcb..1024d3ac9790e801a060d90a9faab1d64237515b 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -1,5 +1,11 @@ package spark.streaming +import akka.actor.Actor +import akka.actor.IO +import akka.actor.IOManager +import akka.actor.Props +import akka.util.ByteString + import dstream.SparkFlumeEvent import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket} import java.io.{File, BufferedWriter, OutputStreamWriter} @@ -7,6 +13,7 @@ import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} import collection.mutable.{SynchronizedBuffer, ArrayBuffer} import util.ManualClock import spark.storage.StorageLevel +import spark.streaming.receivers.Receiver import spark.Logging import scala.util.Random import org.apache.commons.io.FileUtils @@ -19,40 +26,30 @@ import org.apache.avro.ipc.specific.SpecificRequestor import java.nio.ByteBuffer import collection.JavaConversions._ import java.nio.charset.Charset +import com.google.common.io.Files class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") val testPort = 9999 - var testServer: TestServer = null - var testDir: File = null override def checkpointDir = "checkpoint" after { - FileUtils.deleteDirectory(new File(checkpointDir)) - if (testServer != null) { - testServer.stop() - testServer = null - } - if (testDir != null && testDir.exists()) { - FileUtils.deleteDirectory(testDir) - testDir = null - } - // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown System.clearProperty("spark.driver.port") } - test("network input stream") { + + test("socket input stream") { // Start the server - testServer = new TestServer(testPort) + val testServer = new TestServer(testPort) testServer.start() // Set up the streaming context and input streams val ssc = new StreamingContext(master, framework, batchDuration) - val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) + val networkStream = ssc.socketTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]] val outputStream = new TestOutputStream(networkStream, outputBuffer) def output = outputBuffer.flatMap(x => x) @@ -93,46 +90,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } - test("network input stream with checkpoint") { - // Start the server - testServer = new TestServer(testPort) - testServer.start() - - // Set up the streaming context and input streams - var ssc = new StreamingContext(master, framework, batchDuration) - ssc.checkpoint(checkpointDir, checkpointInterval) - val networkStream = ssc.networkTextStream("localhost", testPort, StorageLevel.MEMORY_AND_DISK) - var outputStream = new TestOutputStream(networkStream, new ArrayBuffer[Seq[String]]) - ssc.registerOutputStream(outputStream) - ssc.start() - - // Feed data to the server to send to the network receiver - var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - for (i <- Seq(1, 2, 3)) { - testServer.send(i.toString + "\n") - Thread.sleep(100) - clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(500) - assert(outputStream.output.size > 0) - ssc.stop() - - // Restart stream computation from checkpoint and feed more data to see whether - // they are being received and processed - logInfo("*********** RESTARTING ************") - ssc = new StreamingContext(checkpointDir) - ssc.start() - clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - for (i <- Seq(4, 5, 6)) { - testServer.send(i.toString + "\n") - Thread.sleep(100) - clock.addToTime(batchDuration.milliseconds) - } - Thread.sleep(500) - outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[String]] - assert(outputStream.output.size > 0) - ssc.stop() - } test("flume input stream") { // Set up the streaming context and input streams @@ -146,7 +103,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq(1, 2, 3, 4, 5) - + Thread.sleep(1000) val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", 33333)); val client = SpecificRequestor.getClient( classOf[AvroSourceProtocol], transceiver); @@ -182,42 +139,33 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } } - test("file input stream") { - // Create a temporary directory - testDir = { - var temp = File.createTempFile(".temp.", Random.nextInt().toString) - temp.delete() - temp.mkdirs() - logInfo("Created temp dir " + temp) - temp - } + test("file input stream") { + // Disable manual clock as FileInputDStream does not work with manual clock + System.clearProperty("spark.streaming.clock") // Set up the streaming context and input streams + val testDir = Files.createTempDir() val ssc = new StreamingContext(master, framework, batchDuration) - val filestream = ssc.textFileStream(testDir.toString) + val fileStream = ssc.textFileStream(testDir.toString) val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] def output = outputBuffer.flatMap(x => x) - val outputStream = new TestOutputStream(filestream, outputBuffer) + val outputStream = new TestOutputStream(fileStream, outputBuffer) ssc.registerOutputStream(outputStream) ssc.start() // Create files in the temporary directory so that Spark Streaming can read data from it - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq(1, 2, 3, 4, 5) val expectedOutput = input.map(_.toString) Thread.sleep(1000) for (i <- 0 until input.size) { - FileUtils.writeStringToFile(new File(testDir, i.toString), input(i).toString + "\n") - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) - //Thread.sleep(100) + val file = new File(testDir, i.toString) + FileUtils.writeStringToFile(file, input(i).toString + "\n") + logInfo("Created file " + file) + Thread.sleep(batchDuration.milliseconds) + Thread.sleep(1000) } val startTime = System.currentTimeMillis() - /*while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - logInfo("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size) - Thread.sleep(100) - }*/ Thread.sleep(1000) val timeTaken = System.currentTimeMillis() - startTime assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") @@ -226,75 +174,76 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Verify whether data received by Spark Streaming was as expected logInfo("--------------------------------") - logInfo("output.size = " + outputBuffer.size) - logInfo("output") + logInfo("output, size = " + outputBuffer.size) outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) - logInfo("expected output.size = " + expectedOutput.size) - logInfo("expected output") + logInfo("expected output, size = " + expectedOutput.size) expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) logInfo("--------------------------------") // Verify whether all the elements received are as expected // (whether the elements were received one in each interval is not verified) - assert(output.size === expectedOutput.size) - for (i <- 0 until output.size) { - assert(output(i).size === 1) - assert(output(i).head.toString === expectedOutput(i)) - } + assert(output.toList === expectedOutput.toList) + + FileUtils.deleteDirectory(testDir) + + // Enable manual clock back again for other tests + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") } - test("file input stream with checkpoint") { - // Create a temporary directory - testDir = { - var temp = File.createTempFile(".temp.", Random.nextInt().toString) - temp.delete() - temp.mkdirs() - logInfo("Created temp dir " + temp) - temp - } + + test("actor input stream") { + // Start the server + val port = testPort + val testServer = new TestServer(port) + testServer.start() // Set up the streaming context and input streams - var ssc = new StreamingContext(master, framework, batchDuration) - ssc.checkpoint(checkpointDir, checkpointInterval) - val filestream = ssc.textFileStream(testDir.toString) - var outputStream = new TestOutputStream(filestream, new ArrayBuffer[Seq[String]]) + val ssc = new StreamingContext(master, framework, batchDuration) + val networkStream = ssc.actorStream[String](Props(new TestActor(port)), "TestActor", + StorageLevel.MEMORY_AND_DISK) //Had to pass the local value of port to prevent from closing over entire scope + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(networkStream, outputBuffer) + def output = outputBuffer.flatMap(x => x) ssc.registerOutputStream(outputStream) ssc.start() - // Create files and advance manual clock to process them - var clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + // Feed data to the server to send to the network receiver + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = 1 to 9 + val expectedOutput = input.map(x => x.toString) Thread.sleep(1000) - for (i <- Seq(1, 2, 3)) { - FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") - Thread.sleep(100) + for (i <- 0 until input.size) { + testServer.send(input(i).toString) + Thread.sleep(500) clock.addToTime(batchDuration.milliseconds) } - Thread.sleep(500) - logInfo("Output = " + outputStream.output.mkString(",")) - assert(outputStream.output.size > 0) + Thread.sleep(1000) + logInfo("Stopping server") + testServer.stop() + logInfo("Stopping context") ssc.stop() - // Restart stream computation from checkpoint and create more files to see whether - // they are being processed - logInfo("*********** RESTARTING ************") - ssc = new StreamingContext(checkpointDir) - ssc.start() - clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - Thread.sleep(500) - for (i <- Seq(4, 5, 6)) { - FileUtils.writeStringToFile(new File(testDir, i.toString), i.toString + "\n") - Thread.sleep(100) - clock.addToTime(batchDuration.milliseconds) + // Verify whether data received was as expected + logInfo("--------------------------------") + logInfo("output.size = " + outputBuffer.size) + logInfo("output") + outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("expected output.size = " + expectedOutput.size) + logInfo("expected output") + expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]")) + logInfo("--------------------------------") + + // Verify whether all the elements received are as expected + // (whether the elements were received one in each interval is not verified) + assert(output.size === expectedOutput.size) + for (i <- 0 until output.size) { + assert(output(i) === expectedOutput(i)) } - Thread.sleep(500) - outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[String]] - logInfo("Output = " + outputStream.output.mkString(",")) - assert(outputStream.output.size > 0) - ssc.stop() } } +/** This is server to test the network input stream */ class TestServer(port: Int) extends Logging { val queue = new ArrayBlockingQueue[String](100) @@ -353,3 +302,15 @@ object TestServer { } } } + +class TestActor(port: Int) extends Actor with Receiver { + + def bytesToString(byteString: ByteString) = byteString.utf8String + + override def preStart = IOManager(context.system).connect(new InetSocketAddress(port)) + + def receive = { + case IO.Read(socket, bytes) => + pushBlock(bytesToString(bytes)) + } +} diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index 49129f39640c4d906dee6604fb0b7e1d84e76fee..ad6aa79d102339bc0595d30463d47bdf6ee183fd 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -28,6 +28,11 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ logInfo("Computing RDD for time " + validTime) val index = ((validTime - zeroTime) / slideDuration - 1).toInt val selectedInput = if (index < input.size) input(index) else Seq[T]() + + // lets us test cases where RDDs are not created + if (selectedInput == null) + return None + val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) logInfo("Created RDD " + rdd.id + " with " + selectedInput) Some(rdd) @@ -58,20 +63,25 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu */ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { + // Name of the framework for Spark context def framework = "TestSuiteBase" + // Master for Spark context def master = "local[2]" + // Batch duration def batchDuration = Seconds(1) + // Directory where the checkpoint data will be saved def checkpointDir = "checkpoint" - def checkpointInterval = batchDuration - + // Number of partitions of the input parallel collections created for testing def numInputPartitions = 2 + // Maximum time to wait before the test times out def maxWaitTimeMillis = 10000 + // Whether to actually wait in real time before changing manual clock def actuallyWait = false /** @@ -86,7 +96,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework, batchDuration) if (checkpointDir != null) { - ssc.checkpoint(checkpointDir, checkpointInterval) + ssc.checkpoint(checkpointDir) } // Setup the stream computation @@ -111,7 +121,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework, batchDuration) if (checkpointDir != null) { - ssc.checkpoint(checkpointDir, checkpointInterval) + ssc.checkpoint(checkpointDir) } // Setup the stream computation @@ -135,9 +145,6 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { numBatches: Int, numExpectedOutput: Int ): Seq[Seq[V]] = { - - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") - assert(numBatches > 0, "Number of batches to run stream computation is zero") assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) @@ -181,7 +188,6 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { } finally { ssc.stop() } - output } diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index cd9608df530af437c42befac4589ad7784ed09a0..1b66f3bda20ad0f112295e529f4a3f1419c167ed 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -5,6 +5,8 @@ import collection.mutable.ArrayBuffer class WindowOperationsSuite extends TestSuiteBase { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + override def framework = "WindowOperationsSuite" override def maxWaitTimeMillis = 20000 @@ -82,12 +84,9 @@ class WindowOperationsSuite extends TestSuiteBase { ) /* - The output of the reduceByKeyAndWindow with inverse reduce function is - different from the naive reduceByKeyAndWindow. Even if the count of a - particular key is 0, the key does not get eliminated from the RDDs of - ReducedWindowedDStream. This causes the number of keys in these RDDs to - increase forever. A more generalized version that allows elimination of - keys should be considered. + The output of the reduceByKeyAndWindow with inverse function but without a filter + function will be different from the naive reduceByKeyAndWindow, as no keys get + eliminated from the ReducedWindowedDStream even if the value of a key becomes 0. */ val bigReduceInvOutput = Seq( @@ -175,31 +174,31 @@ class WindowOperationsSuite extends TestSuiteBase { // Testing reduceByKeyAndWindow (with invertible reduce function) - testReduceByKeyAndWindowInv( + testReduceByKeyAndWindowWithInverse( "basic reduction", Seq(Seq(("a", 1), ("a", 3)) ), Seq(Seq(("a", 4)) ) ) - testReduceByKeyAndWindowInv( + testReduceByKeyAndWindowWithInverse( "key already in window and new value added into window", Seq( Seq(("a", 1)), Seq(("a", 1)) ), Seq( Seq(("a", 1)), Seq(("a", 2)) ) ) - testReduceByKeyAndWindowInv( + testReduceByKeyAndWindowWithInverse( "new key added into window", Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ), Seq( Seq(("a", 1)), Seq(("a", 2), ("b", 1)) ) ) - testReduceByKeyAndWindowInv( + testReduceByKeyAndWindowWithInverse( "key removed from window", Seq( Seq(("a", 1)), Seq(("a", 1)), Seq(), Seq() ), Seq( Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 1)), Seq(("a", 0)) ) ) - testReduceByKeyAndWindowInv( + testReduceByKeyAndWindowWithInverse( "larger slide time", largerSlideInput, largerSlideReduceOutput, @@ -207,7 +206,9 @@ class WindowOperationsSuite extends TestSuiteBase { Seconds(2) ) - testReduceByKeyAndWindowInv("big test", bigInput, bigReduceInvOutput) + testReduceByKeyAndWindowWithInverse("big test", bigInput, bigReduceInvOutput) + + testReduceByKeyAndWindowWithFilteredInverse("big test", bigInput, bigReduceOutput) test("groupByKeyAndWindow") { val input = bigInput @@ -235,14 +236,14 @@ class WindowOperationsSuite extends TestSuiteBase { testOperation(input, operation, expectedOutput, numBatches, true) } - test("countByKeyAndWindow") { - val input = Seq(Seq(("a", 1)), Seq(("b", 1), ("b", 2)), Seq(("a", 10), ("b", 20))) + test("countByValueAndWindow") { + val input = Seq(Seq("a"), Seq("b", "b"), Seq("a", "b")) val expectedOutput = Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 2)), Seq(("a", 1), ("b", 3))) val windowDuration = Seconds(2) val slideDuration = Seconds(1) val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt - val operation = (s: DStream[(String, Int)]) => { - s.countByKeyAndWindow(windowDuration, slideDuration).map(x => (x._1, x._2.toInt)) + val operation = (s: DStream[String]) => { + s.countByValueAndWindow(windowDuration, slideDuration).map(x => (x._1, x._2.toInt)) } testOperation(input, operation, expectedOutput, numBatches, true) } @@ -272,29 +273,50 @@ class WindowOperationsSuite extends TestSuiteBase { slideDuration: Duration = Seconds(1) ) { test("reduceByKeyAndWindow - " + name) { + logInfo("reduceByKeyAndWindow - " + name) val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, windowDuration, slideDuration).persist() + s.reduceByKeyAndWindow((x: Int, y: Int) => x + y, windowDuration, slideDuration) } testOperation(input, operation, expectedOutput, numBatches, true) } } - def testReduceByKeyAndWindowInv( + def testReduceByKeyAndWindowWithInverse( name: String, input: Seq[Seq[(String, Int)]], expectedOutput: Seq[Seq[(String, Int)]], windowDuration: Duration = Seconds(2), slideDuration: Duration = Seconds(1) ) { - test("reduceByKeyAndWindowInv - " + name) { + test("reduceByKeyAndWindow with inverse function - " + name) { + logInfo("reduceByKeyAndWindow with inverse function - " + name) val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt val operation = (s: DStream[(String, Int)]) => { s.reduceByKeyAndWindow(_ + _, _ - _, windowDuration, slideDuration) - .persist() .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing } testOperation(input, operation, expectedOutput, numBatches, true) } } + + def testReduceByKeyAndWindowWithFilteredInverse( + name: String, + input: Seq[Seq[(String, Int)]], + expectedOutput: Seq[Seq[(String, Int)]], + windowDuration: Duration = Seconds(2), + slideDuration: Duration = Seconds(1) + ) { + test("reduceByKeyAndWindow with inverse and filter functions - " + name) { + logInfo("reduceByKeyAndWindow with inverse and filter functions - " + name) + val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt + val filterFunc = (p: (String, Int)) => p._2 != 0 + val operation = (s: DStream[(String, Int)]) => { + s.reduceByKeyAndWindow(_ + _, _ - _, windowDuration, slideDuration, filterFunc = filterFunc) + .persist() + .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing + } + testOperation(input, operation, expectedOutput, numBatches, true) + } + } }