diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 5fd1fab5809da1ddc9cf6d2162c73a9532a53cf5..f9b6ee351a151cefcc310d17d50249fdc20b1d5b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -48,6 +48,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav */ def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel)) + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * This method blocks until all blocks are deleted. + */ + def unpersist(): JavaDoubleRDD = fromRDD(srdd.unpersist()) + + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. + */ + def unpersist(blocking: Boolean): JavaDoubleRDD = fromRDD(srdd.unpersist(blocking)) + // first() has to be overriden here in order for its return type to be Double instead of Object. override def first(): Double = srdd.first() diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index c099ca77b949a795c047179c30e8460dab3aa510..b3eb739f4e701617ab8ad9d64589f140d73dbc7a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -65,6 +65,19 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif def persist(newLevel: StorageLevel): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.persist(newLevel)) + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * This method blocks until all blocks are deleted. + */ + def unpersist(): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist()) + + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. + */ + def unpersist(blocking: Boolean): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist(blocking)) + // Transformations (return a new RDD) /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index eec58abdd608a462997b6101ac645d8c8be9cefe..662990049b0938f542e5966c992a2bd1554a5586 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -41,9 +41,17 @@ JavaRDDLike[T, JavaRDD[T]] { /** * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * This method blocks until all blocks are deleted. */ def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist()) + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. 
+ */ + def unpersist(blocking: Boolean): JavaRDD[T] = wrapRDD(rdd.unpersist(blocking)) + // Transformations (return a new RDD) /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dda194d9537c879d630ceba631d9301990d9c146..4cef0825dd6c0aab711df8a58700bd37fb91c0e0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -68,6 +68,11 @@ class DAGScheduler( eventQueue.put(BeginEvent(task, taskInfo)) } + // Called to report that a task has completed and results are being fetched remotely. + def taskGettingResult(task: Task[_], taskInfo: TaskInfo) { + eventQueue.put(GettingResultEvent(task, taskInfo)) + } + // Called by TaskScheduler to report task completions or failures. def taskEnded( task: Task[_], @@ -415,6 +420,9 @@ class DAGScheduler( case begin: BeginEvent => listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo)) + case gettingResult: GettingResultEvent => + listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo)) + case completion: CompletionEvent => listenerBus.post(SparkListenerTaskEnd( completion.task, completion.reason, completion.taskInfo, completion.taskMetrics)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index a5769c604195b572196c4171141cdd95a5f81bad..708d221d60caf8cd981780d513a89d372a1baf97 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -53,6 +53,9 @@ private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent private[scheduler] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent +private[scheduler] +case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent + private[scheduler] case class CompletionEvent( task: Task[_], reason: TaskEndReason, diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 324cd639b0a5710a74006ef19963903800dff521..a35081f7b10d7040d8b45302ce50941cef3e7960 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -31,6 +31,9 @@ case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents +case class SparkListenerTaskGettingResult( + task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents + case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, taskMetrics: TaskMetrics) extends SparkListenerEvents @@ -56,6 +59,12 @@ trait SparkListener { */ def onTaskStart(taskStart: SparkListenerTaskStart) { } + /** + * Called when a task begins remotely fetching its result (will not be called for tasks that do + * not need to fetch the result remotely). 
+ */ + def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } + /** * Called when a task ends */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 4d3e4a17ba5620281d0b0e7238fd39907d814e46..d5824e79547974e643b348b12465fa6fe78a2fe0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging { sparkListeners.foreach(_.onJobEnd(jobEnd)) case taskStart: SparkListenerTaskStart => sparkListeners.foreach(_.onTaskStart(taskStart)) + case taskGettingResult: SparkListenerTaskGettingResult => + sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult)) case taskEnd: SparkListenerTaskEnd => sparkListeners.foreach(_.onTaskEnd(taskEnd)) case _ => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 7c2a422affbbfbe4be65817121d22c06f1bb3dfd..4bae26f3a6a885c73bd1639d61d226cbd06a5ea2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -31,9 +31,25 @@ class TaskInfo( val host: String, val taskLocality: TaskLocality.TaskLocality) { + /** + * The time when the task started remotely getting the result. Will not be set if the + * task result was sent immediately when the task finished (as opposed to sending an + * IndirectTaskResult and later fetching the result from the block manager). + */ + var gettingResultTime: Long = 0 + + /** + * The time when the task has completed successfully (including the time to remotely fetch + * results, if necessary). 
+ */ var finishTime: Long = 0 + var failed = false + def markGettingResult(time: Long = System.currentTimeMillis) { + gettingResultTime = time + } + def markSuccessful(time: Long = System.currentTimeMillis) { finishTime = time } @@ -43,6 +59,8 @@ class TaskInfo( failed = true } + def gettingResult: Boolean = gettingResultTime != 0 + def finished: Boolean = finishTime != 0 def successful: Boolean = finished && !failed @@ -52,6 +70,8 @@ class TaskInfo( def status: String = { if (running) "RUNNING" + else if (gettingResult) + "GET RESULT" else if (failed) "FAILED" else if (successful) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index 4ea8bf88534cf1d90b691b5265101a9e235eaf5d..85033958ef54f4e1568a3023630621e2f9cd7b35 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -306,6 +306,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } + def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) { + taskSetManager.handleTaskGettingResult(tid) + } + def handleSuccessfulTask( taskSetManager: ClusterTaskSetManager, tid: Long, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 29093e3b4f511b1fe9d1e4fcadc9d20ac6e3d1bc..ee47aaffcae11d1e341626791047c5e1ae2bc9ca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -418,6 +418,12 @@ private[spark] class ClusterTaskSetManager( sched.dagScheduler.taskStarted(task, info) } + def handleTaskGettingResult(tid: Long) = { + val info = taskInfos(tid) + info.markGettingResult() + sched.dagScheduler.taskGettingResult(tasks(info.index), info) + } + /** * Marks the task as successful and notifies the DAGScheduler that a task has ended. 
*/ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala index 4312c46cc190c1279318942f0c150e743c36fe14..2064d97b49cc04f35cd638a65c57112abf3d9956 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala @@ -50,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche case directResult: DirectTaskResult[_] => directResult case IndirectTaskResult(blockId) => logDebug("Fetching indirect task result for TID %s".format(tid)) + scheduler.handleTaskGettingResult(taskSetManager, tid) val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId) if (!serializedTaskResult.isDefined) { /* We won't be able to get the task result if the machine that ran the task failed diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 9bb8a13ec45d86c5dda068aea29db989acfc8dcf..6b854740d6a2425e51cac0505a2e6ecc58cecf30 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -115,7 +115,13 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList taskList += ((taskStart.taskInfo, None, None)) stageIdToTaskInfos(sid) = taskList } - + + override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) + = synchronized { + // Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in + // stageToTaskInfos already has the updated status. + } + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val sid = taskEnd.task.stageId val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 42ca988f7a12995b2e6f322e587cfae7932b9b39..f7f599532a96c3c61ec8a1ed46359dca8a2f5b90 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -17,22 +17,25 @@ package org.apache.spark.scheduler -import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.{LocalSparkContext, SparkContext} -import scala.collection.mutable +import scala.collection.mutable.{Buffer, HashSet} + +import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalatest.matchers.ShouldMatchers + +import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.SparkContext._ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers - with BeforeAndAfter { + with BeforeAndAfterAll { /** Length of time to wait while draining listener events. 
*/ val WAIT_TIMEOUT_MILLIS = 10000 - before { - sc = new SparkContext("local", "DAGSchedulerSuite") + override def afterAll { + System.clearProperty("spark.akka.frameSize") } test("basic creation of StageInfo") { + sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -53,6 +56,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("StageInfo with fewer tasks than partitions") { + sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) @@ -68,6 +72,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } test("local metrics") { + sc = new SparkContext("local", "DAGSchedulerSuite") val listener = new SaveStageInfo sc.addSparkListener(listener) sc.addSparkListener(new StatsReportListener) @@ -129,15 +134,73 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } } + test("onTaskGettingResult() called when result fetched remotely") { + // Need to use local cluster mode here, because results are not ever returned through the + // block manager when using the LocalScheduler. + sc = new SparkContext("local-cluster[1,1,512]", "test") + + val listener = new SaveTaskEvents + sc.addSparkListener(listener) + + // Make a task whose result is larger than the akka frame size + System.setProperty("spark.akka.frameSize", "1") + val akkaFrameSize = + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x,y) => x) + assert(result === 1.to(akkaFrameSize).toArray) + + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val TASK_INDEX = 0 + assert(listener.startedTasks.contains(TASK_INDEX)) + assert(listener.startedGettingResultTasks.contains(TASK_INDEX)) + assert(listener.endedTasks.contains(TASK_INDEX)) + } + + test("onTaskGettingResult() not called when result sent directly") { + // Need to use local cluster mode here, because results are not ever returned through the + // block manager when using the LocalScheduler. 
+ sc = new SparkContext("local-cluster[1,1,512]", "test") + + val listener = new SaveTaskEvents + sc.addSparkListener(listener) + + // Make a task whose result is larger than the akka frame size + val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x) + assert(result === 2) + + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val TASK_INDEX = 0 + assert(listener.startedTasks.contains(TASK_INDEX)) + assert(listener.startedGettingResultTasks.isEmpty == true) + assert(listener.endedTasks.contains(TASK_INDEX)) + } + def checkNonZeroAvg(m: Traversable[Long], msg: String) { assert(m.sum / m.size.toDouble > 0.0, msg) } class SaveStageInfo extends SparkListener { - val stageInfos = mutable.Buffer[StageInfo]() + val stageInfos = Buffer[StageInfo]() override def onStageCompleted(stage: StageCompleted) { stageInfos += stage.stage } } + class SaveTaskEvents extends SparkListener { + val startedTasks = new HashSet[Int]() + val startedGettingResultTasks = new HashSet[Int]() + val endedTasks = new HashSet[Int]() + + override def onTaskStart(taskStart: SparkListenerTaskStart) { + startedTasks += taskStart.taskInfo.index + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + endedTasks += taskEnd.taskInfo.index + } + + override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { + startedGettingResultTasks += taskGettingResult.taskInfo.index + } + } } diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala new file mode 100644 index 0000000000000000000000000000000000000000..af698a01d511871472751f1fdb779e808cb4d6ab --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.examples + +import org.apache.spark.streaming.{ Seconds, StreamingContext } +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.MQTTReceiver +import org.apache.spark.storage.StorageLevel + +import org.eclipse.paho.client.mqttv3.MqttClient +import org.eclipse.paho.client.mqttv3.MqttClientPersistence +import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence +import org.eclipse.paho.client.mqttv3.MqttException +import org.eclipse.paho.client.mqttv3.MqttMessage +import org.eclipse.paho.client.mqttv3.MqttTopic + +/** + * A simple Mqtt publisher for demonstration purposes, repeatedly publishes + * Space separated String Message "hello mqtt demo for spark streaming" + */ +object MQTTPublisher { + + var client: MqttClient = _ + + def main(args: Array[String]) { + if (args.length < 2) { + System.err.println("Usage: MQTTPublisher <MqttBrokerUrl> <topic>") + System.exit(1) + } + + val Seq(brokerUrl, topic) = args.toSeq + + try { + var peristance:MqttClientPersistence =new MqttDefaultFilePersistence("/tmp") + client = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance) + } catch { + case e: MqttException => println("Exception Caught: " + e) + } + + client.connect() + + val msgtopic: MqttTopic = client.getTopic(topic); + val msg: String = "hello mqtt demo for spark streaming" + + while (true) { + val message: MqttMessage = new MqttMessage(String.valueOf(msg).getBytes()) + msgtopic.publish(message); + println("Published data. topic: " + msgtopic.getName() + " Message: " + message) + } + client.disconnect() + } +} + +/** + * A sample wordcount with MqttStream stream + * + * To work with Mqtt, Mqtt Message broker/server required. + * Mosquitto (http://mosquitto.org/) is an open source Mqtt Broker + * In ubuntu mosquitto can be installed using the command `$ sudo apt-get install mosquitto` + * Eclipse paho project provides Java library for Mqtt Client http://www.eclipse.org/paho/ + * Example Java code for Mqtt Publisher and Subscriber can be found here https://bitbucket.org/mkjinesh/mqttclient + * Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic> + * In local mode, <master> should be 'local[n]' with n > 1 + * <MqttbrokerUrl> and <topic> describe where Mqtt publisher is running. 
+ * + * To run this example locally, you may run publisher as + * `$ ./run-example org.apache.spark.streaming.examples.MQTTPublisher tcp://localhost:1883 foo` + * and run the example as + * `$ ./run-example org.apache.spark.streaming.examples.MQTTWordCount local[2] tcp://localhost:1883 foo` + */ +object MQTTWordCount { + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println( + "Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic>" + + " In local mode, <master> should be 'local[n]' with n > 1") + System.exit(1) + } + + val Seq(master, brokerUrl, topic) = args.toSeq + + val ssc = new StreamingContext(master, "MqttWordCount", Seconds(2), System.getenv("SPARK_HOME"), + Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val lines = ssc.mqttStream(brokerUrl, topic, StorageLevel.MEMORY_ONLY) + + val words = lines.flatMap(x => x.toString.split(" ")) + val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() + ssc.start() + } +} diff --git a/pom.xml b/pom.xml index 54f100c37fc2b1f69e2a717d1236c27d5672a9aa..53ac82efd0247ebf45accfdc465d3fd0377c9901 100644 --- a/pom.xml +++ b/pom.xml @@ -147,6 +147,17 @@ <enabled>false</enabled> </snapshots> </repository> + <repository> + <id>mqtt-repo</id> + <name>MQTT Repository</name> + <url>https://repo.eclipse.org/content/repositories/paho-releases/</url> + <releases> + <enabled>true</enabled> + </releases> + <snapshots> + <enabled>false</enabled> + </snapshots> + </repository> </repositories> <pluginRepositories> <pluginRepository> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 17f480e3f01015df4befbb781e8abfaf83d6928d..8d7cbae8214e49489ae2ee673bdab74450a42d2c 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -108,7 +108,10 @@ object SparkBuild extends Build { // Shared between both core and streaming. resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"), - // For Sonatype publishing + // Shared between both examples and streaming. 
+ resolvers ++= Seq("Mqtt Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/"), + + // For Sonatype publishing resolvers ++= Seq("sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", "sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/"), @@ -282,10 +285,11 @@ object SparkBuild extends Build { "Apache repo" at "https://repository.apache.org/content/repositories/releases" ), libraryDependencies ++= Seq( + "org.eclipse.paho" % "mqtt-client" % "0.4.0", "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty, excludeSnappy), "org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty), "com.typesafe.akka" % "akka-zeromq" % "2.0.5" excludeAll(excludeNetty), - "org.apache.kafka" % "kafka_2.9.2" % "0.8.0-beta1" + "org.apache.kafka" % "kafka_2.9.2" % "0.8.0-beta1" exclude("com.sun.jdmk", "jmxtools") exclude("com.sun.jmx", "jmxri") ) diff --git a/streaming/pom.xml b/streaming/pom.xml index 339fcd2a391e3751b53f9dca0bb9f1458899dbf2..8022c4fe18917a1a671ceed7d779c2bd54ae6705 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -136,6 +136,11 @@ <artifactId>slf4j-log4j12</artifactId> <scope>test</scope> </dependency> + <dependency> + <groupId>org.eclipse.paho</groupId> + <artifactId>mqtt-client</artifactId> + <version>0.4.0</version> + </dependency> </dependencies> <build> <outputDirectory>target/scala-${scala.version}/classes</outputDirectory> diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 70bc25070abf199a34ecb42798ea0ab75c112b67..70bf902143d8ee8bfd2dd08c700289cc8801c9cb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -462,6 +462,21 @@ class StreamingContext private ( inputStream } +/** + * Create an input stream that receives messages pushed by a mqtt publisher. + * @param brokerUrl Url of remote mqtt publisher + * @param topic topic name to subscribe to + * @param storageLevel RDD storage level. Defaults to memory-only. + */ + + def mqttStream( + brokerUrl: String, + topic: String, + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2): DStream[String] = { + val inputStream = new MQTTInputDStream[String](this, brokerUrl, topic, storageLevel) + registerInputStream(inputStream) + inputStream + } /** * Create a unified DStream from multiple DStreams of the same type and same slide duration. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MQTTInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MQTTInputDStream.scala new file mode 100644 index 0000000000000000000000000000000000000000..ac0528213d3290832d458d4eea91992d99f3cbe9 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MQTTInputDStream.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.dstream + +import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{ Time, DStreamCheckpointData, StreamingContext } + +import java.util.Properties +import java.util.concurrent.Executors +import java.io.IOException + +import org.eclipse.paho.client.mqttv3.MqttCallback +import org.eclipse.paho.client.mqttv3.MqttClient +import org.eclipse.paho.client.mqttv3.MqttClientPersistence +import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence +import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken +import org.eclipse.paho.client.mqttv3.MqttException +import org.eclipse.paho.client.mqttv3.MqttMessage +import org.eclipse.paho.client.mqttv3.MqttTopic + +import scala.collection.Map +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ + +/** + * Input stream that subscribe messages from a Mqtt Broker. + * Uses eclipse paho as MqttClient http://www.eclipse.org/paho/ + * @param brokerUrl Url of remote mqtt publisher + * @param topic topic name to subscribe to + * @param storageLevel RDD storage level. + */ + +private[streaming] +class MQTTInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + brokerUrl: String, + topic: String, + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_) with Logging { + + def getReceiver(): NetworkReceiver[T] = { + new MQTTReceiver(brokerUrl, topic, storageLevel) + .asInstanceOf[NetworkReceiver[T]] + } +} + +private[streaming] +class MQTTReceiver(brokerUrl: String, + topic: String, + storageLevel: StorageLevel + ) extends NetworkReceiver[Any] { + lazy protected val blockGenerator = new BlockGenerator(storageLevel) + + def onStop() { + blockGenerator.stop() + } + + def onStart() { + + blockGenerator.start() + + // Set up persistence for messages + var peristance: MqttClientPersistence = new MemoryPersistence() + + // Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance + var client: MqttClient = new MqttClient(brokerUrl, "MQTTSub", peristance) + + // Connect to MqttBroker + client.connect() + + // Subscribe to Mqtt topic + client.subscribe(topic) + + // Callback automatically triggers as and when new message arrives on specified topic + var callback: MqttCallback = new MqttCallback() { + + // Handles Mqtt message + override def messageArrived(arg0: String, arg1: MqttMessage) { + blockGenerator += new String(arg1.getPayload()) + } + + override def deliveryComplete(arg0: IMqttDeliveryToken) { + } + + override def connectionLost(arg0: Throwable) { + logInfo("Connection lost " + arg0) + } + } + + // Set up callback for MqttClient + client.setCallback(callback) + } +}
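Example usage of the unpersist() overloads added to the Java API classes above, as a minimal sketch driven from Scala through JavaSparkContext (the master URL, app name, and data are illustrative only):

    import java.util.Arrays
    import org.apache.spark.api.java.JavaSparkContext

    val jsc = new JavaSparkContext("local", "unpersist-example")
    val rdd = jsc.parallelize(Arrays.asList(1, 2, 3)).cache()
    rdd.count()         // materialize the cached blocks
    rdd.unpersist()     // no-arg variant blocks until all blocks are deleted
    // rdd.unpersist(false) would return without waiting for the blocks to be removed
    jsc.stop()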
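Similarly, a minimal sketch of a SparkListener that consumes the new SparkListenerTaskGettingResult event together with TaskInfo.gettingResultTime to log how long remote result fetches take (the listener name and log messages are illustrative, not part of this patch):

    import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskGettingResult}

    class ResultFetchTimingListener extends SparkListener {
      // Fired only for tasks whose result came back as an IndirectTaskResult and
      // had to be fetched from the block manager.
      override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
        val info = taskGettingResult.taskInfo
        println("Task %d started fetching its result remotely".format(info.index))
      }

      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
        val info = taskEnd.taskInfo
        // gettingResultTime stays 0 when the result was small enough to be sent back directly.
        if (info.gettingResult) {
          println("Task %d spent %d ms fetching its result"
            .format(info.index, info.finishTime - info.gettingResultTime))
        }
      }
    }

    // Register on an existing SparkContext: sc.addSparkListener(new ResultFetchTimingListener)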