diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 0bc5b2ff112eedc40b2cdcc81cd8300642f3e021..f5a3c2990beb09b40386de7b9b9f395649fd48bf 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -2,6 +2,7 @@ package spark import scala.collection.immutable.NumericRange import scala.collection.mutable.ArrayBuffer +import scala.collection.Map private[spark] class ParallelCollectionSplit[T: ClassManifest]( val rddId: Long, @@ -24,7 +25,8 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest]( private[spark] class ParallelCollection[T: ClassManifest]( @transient sc : SparkContext, @transient data: Seq[T], - numSlices: Int) + numSlices: Int, + locationPrefs : Map[Int,Seq[String]]) extends RDD[T](sc, Nil) { // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split @@ -41,6 +43,13 @@ private[spark] class ParallelCollection[T: ClassManifest]( override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator + override def getPreferredLocations(s: Split): Seq[String] = { + locationPrefs.get(s.index) match { + case Some(s) => s + case _ => Nil + } + } + override def clearDependencies() { splits_ = null } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 362aa04e66bf87a381f43f05e37d1a1a90662aa4..709b5c38d9394af1613b58fd15b1615de76c62d0 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -190,7 +190,7 @@ class SparkContext( /** Distribute a local Scala collection to form an RDD. */ def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { - new ParallelCollection[T](this, seq, numSlices) + new ParallelCollection[T](this, seq, numSlices, Map[Int, Seq[String]]()) } /** Distribute a local Scala collection to form an RDD. */ @@ -198,6 +198,14 @@ class SparkContext( parallelize(seq, numSlices) } + /** Distribute a local Scala collection to form an RDD, with one or more + * location preferences (hostnames of Spark nodes) for each object. + * Create a new partition for each collection item. */ + def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { + val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap + new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs) + } + /** * Read a text file from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI, and return it as an RDD of Strings. 
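For context, a minimal usage sketch (not part of the patch) of the new makeRDD overload added above, which creates one partition per element and exposes the given hostnames through ParallelCollection.getPreferredLocations; the master URL and hostnames are placeholders:

import spark.SparkContext

object MakeRDDWithPreferencesExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "MakeRDDWithPreferencesExample")
    // One partition per element; each partition carries its own preferred hostnames.
    val rdd = sc.makeRDD[String](Seq(
      ("a", Seq("host1")),
      ("b", Seq("host2")),
      ("c", Seq("host1", "host2"))     // placeholder hostnames
    ))
    println(rdd.splits.length)           // 3 partitions, one per element
    println(rdd.collect().mkString(",")) // a,b,c
    sc.stop()
  }
}

NetworkInputTracker.startReceivers() later in this patch uses this same overload to pin each network receiver to its preferred host.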
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 6bc667bd4c169cefb83b1b5806b3eca52ec962ab..1f025d175b7bfd9d8ff2f6090fcc6ed7bb0e0ee0 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -186,19 +186,19 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // Test whether the checkpoint file has been created assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result) - + // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD) - + // Test whether the splits have been changed to the new Hadoop splits assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.getSplits.toList) // Test whether the number of splits is same as before assert(operatedRDD.splits.length === numSplits) - + // Test whether the data in the checkpointed RDD is same as original assert(operatedRDD.collect() === result) - + // Test whether serialized size of the RDD has reduced. If the RDD // does not have any dependency to another RDD (e.g., ParallelCollection, // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing. @@ -211,7 +211,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" ) } - + // Test whether serialized size of the splits has reduced. If the splits // do not have any non-transient reference to another RDD or another RDD's splits, it // does not refer to a lineage and therefore may not reduce in size after checkpointing. @@ -255,7 +255,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { // Test whether the data in the checkpointed RDD is same as original assert(operatedRDD.collect() === result) - + // Test whether serialized size of the RDD has reduced because of its parent being // checkpointed. If this RDD or its parent RDD do not have any dependency // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may @@ -267,7 +267,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]" ) } - + // Test whether serialized size of the splits has reduced because of its parent being // checkpointed. 
If the splits do not have any non-transient reference to another RDD // or another RDD's splits, it does not refer to a lineage and therefore may not reduce @@ -281,7 +281,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]" ) } - + } /** diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 5875506179fbe00e5ba2757adbf473132d1a23ae..6bd9836a936510cd7a0756d07dde6aee0fffab79 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -44,6 +44,8 @@ public class JavaAPISuite implements Serializable { public void tearDown() { sc.stop(); sc = null; + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.master.port"); } static class ReverseIntComparator implements Comparator<Integer>, Serializable { diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 688bb16a03117a48ed9eda17c111824766e328e2..05f3c59681ecb5a0ab1d3805c6611b0330bd23a8 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -91,7 +91,8 @@ object SparkBuild extends Build { "org.eclipse.jetty" % "jetty-server" % "7.5.3.v20111011", "org.scalatest" %% "scalatest" % "1.6.1" % "test", "org.scalacheck" %% "scalacheck" % "1.9" % "test", - "com.novocode" % "junit-interface" % "0.8" % "test" + "com.novocode" % "junit-interface" % "0.8" % "test", + "org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" ), parallelExecution := false, /* Workaround for issue #206 (fixed after SBT 0.11.0) */ @@ -156,7 +157,9 @@ object SparkBuild extends Build { def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") def streamingSettings = sharedSettings ++ Seq( - name := "spark-streaming" + name := "spark-streaming", + libraryDependencies ++= Seq( + "com.github.sgroschupf" % "zkclient" % "0.1") ) ++ assemblySettings ++ extraAssemblySettings def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq( diff --git a/streaming/lib/kafka-0.7.2.jar b/streaming/lib/kafka-0.7.2.jar new file mode 100644 index 0000000000000000000000000000000000000000..65f79925a4d06a41b7b98d7e6f92fedb408c9b3a Binary files /dev/null and b/streaming/lib/kafka-0.7.2.jar differ diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index 69fefa21a0e90a54f07ef15ab9b3266107014c32..d5048aeed7b06e5589fb8c91028ba90d1f4a1769 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -37,6 +37,9 @@ import org.apache.hadoop.conf.Configuration * - A time interval at which the DStream generates an RDD * - A function that is used to generate an RDD after each time interval */ + +case class DStreamCheckpointData(rdds: HashMap[Time, Any]) + abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) extends Serializable with Logging { @@ -79,7 +82,7 @@ extends Serializable with Logging { // Checkpoint details protected[streaming] val mustCheckpoint = false protected[streaming] var checkpointInterval: Time = null - protected[streaming] val checkpointData = new HashMap[Time, Any]() + protected[streaming] var checkpointData = new DStreamCheckpointData(HashMap[Time, Any]()) // Reference to whole DStream graph protected[streaming] var graph: DStreamGraph = null @@ -314,6 +317,15 @@ extends Serializable with Logging { 
dependencies.foreach(_.forgetOldRDDs(time)) } + /* Adds metadata to the Stream while it is running. + * This method should be overridden by subclasses of InputDStream. + */ + protected[streaming] def addMetadata(metadata: Any) { + if (metadata != null) { + logInfo("Dropping Metadata: " + metadata.toString) + } + } + /** * Refreshes the list of checkpointed RDDs that will be saved along with checkpoint of * this stream. This is an internal method that should not be called directly. This is @@ -325,29 +337,28 @@ extends Serializable with Logging { logInfo("Updating checkpoint data for time " + currentTime) // Get the checkpointed RDDs from the generated RDDs + val newRdds = generatedRDDs.filter(_._2.getCheckpointFile.isDefined) + .map(x => (x._1, x._2.getCheckpointFile.get)) - val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointFile.isDefined) - .map(x => (x._1, x._2.getCheckpointFile.get)) - // Make a copy of the existing checkpoint data - val oldCheckpointData = checkpointData.clone() + // Make a copy of the existing checkpoint data (checkpointed RDDs) + val oldRdds = checkpointData.rdds.clone() - // If the new checkpoint has checkpoints then replace existing with the new one - if (newCheckpointData.size > 0) { - checkpointData.clear() - checkpointData ++= newCheckpointData + // If the new checkpoint data has checkpoints then replace existing with the new one + if (newRdds.size > 0) { + checkpointData.rdds.clear() + checkpointData.rdds ++= newRdds } - // Make dependencies update their checkpoint data + // Make parent DStreams update their checkpoint data dependencies.foreach(_.updateCheckpointData(currentTime)) // TODO: remove this, this is just for debugging - newCheckpointData.foreach { + newRdds.foreach { case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } } - // If old checkpoint files have been removed from checkpoint data, then remove the files - if (newCheckpointData.size > 0) { - (oldCheckpointData -- newCheckpointData.keySet).foreach { + if (newRdds.size > 0) { + (oldRdds -- newRdds.keySet).foreach { case (time, data) => { val path = new Path(data.toString) val fs = path.getFileSystem(new Configuration()) @@ -356,8 +367,8 @@ extends Serializable with Logging { } } } - logInfo("Updated checkpoint data for time " + currentTime + ", " + checkpointData.size + " checkpoints, " - + "[" + checkpointData.mkString(",") + "]") + logInfo("Updated checkpoint data for time " + currentTime + ", " + checkpointData.rdds.size + " checkpoints, " + + "[" + checkpointData.rdds.mkString(",") + "]") } /** @@ -368,8 +379,8 @@ extends Serializable with Logging { */ protected[streaming] def restoreCheckpointData() { // Create RDDs from the checkpoint data - logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs") - checkpointData.foreach { + logInfo("Restoring checkpoint data from " + checkpointData.rdds.size + " checkpointed RDDs") + checkpointData.rdds.foreach { case(time, data) => { logInfo("Restoring checkpointed RDD for time " + time + " from file '" + data.toString + "'") val rdd = ssc.sc.checkpointFile[T](data.toString) diff --git a/streaming/src/main/scala/spark/streaming/DataHandler.scala b/streaming/src/main/scala/spark/streaming/DataHandler.scala new file mode 100644 index 0000000000000000000000000000000000000000..05f307a8d1296608dafefca40bd11b9f7f5ddb21 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/DataHandler.scala @@ -0,0 +1,83 @@ +package spark.streaming + +import
java.util.concurrent.ArrayBlockingQueue +import scala.collection.mutable.ArrayBuffer +import spark.Logging +import spark.streaming.util.{RecurringTimer, SystemClock} +import spark.storage.StorageLevel + + +/** + * This is a helper object that manages the data received by a network receiver. It divides + * the data received into small batches of 100s of milliseconds, pushes them as + * blocks into the block manager and reports the block IDs to the network input + * tracker. It starts two threads, one to periodically start a new batch and prepare + * the previous batch as a block, the other to push the blocks into the block + * manager. + */ + class DataHandler[T](receiver: NetworkReceiver[T], storageLevel: StorageLevel) + extends Serializable with Logging { + + case class Block(id: String, iterator: Iterator[T], metadata: Any = null) + + val clock = new SystemClock() + val blockInterval = 200L + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockStorageLevel = storageLevel + val blocksForPushing = new ArrayBlockingQueue[Block](1000) + val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + var currentBuffer = new ArrayBuffer[T] + + def createBlock(blockId: String, iterator: Iterator[T]) : Block = { + new Block(blockId, iterator) + } + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Data handler started") + } + + def stop() { + blockIntervalTimer.stop() + blockPushingThread.interrupt() + logInfo("Data handler stopped") + } + + def += (obj: T) { + currentBuffer += obj + } + + def updateCurrentBuffer(time: Long) { + try { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[T] + if (newBlockBuffer.size > 0) { + val blockId = "input-" + receiver.streamId + "- " + (time - blockInterval) + val newBlock = createBlock(blockId, newBlockBuffer.toIterator) + blocksForPushing.add(newBlock) + } + } catch { + case ie: InterruptedException => + logInfo("Block interval timer thread interrupted") + case e: Exception => + receiver.stop() + } + } + + def keepPushingBlocks() { + logInfo("Block pushing thread started") + try { + while(true) { + val block = blocksForPushing.take() + receiver.pushBlock(block.id, block.iterator, block.metadata, storageLevel) + } + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread interrupted") + case e: Exception => + receiver.stop() + } + } + } \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala new file mode 100644 index 0000000000000000000000000000000000000000..2959ce4540de17b10e7a080ff6ebc44f46c2e5e3 --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/FlumeInputDStream.scala @@ -0,0 +1,130 @@ +package spark.streaming + +import java.io.{ObjectInput, ObjectOutput, Externalizable} +import spark.storage.StorageLevel +import org.apache.flume.source.avro.AvroSourceProtocol +import org.apache.flume.source.avro.AvroFlumeEvent +import org.apache.flume.source.avro.Status +import org.apache.avro.ipc.specific.SpecificResponder +import org.apache.avro.ipc.NettyServer +import java.net.InetSocketAddress +import collection.JavaConversions._ +import spark.Utils +import java.nio.ByteBuffer + +class FlumeInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + storageLevel: StorageLevel +) extends NetworkInputDStream[SparkFlumeEvent](ssc_) { + + override def
createReceiver(): NetworkReceiver[SparkFlumeEvent] = { + new FlumeReceiver(id, host, port, storageLevel) + } +} + +/** + * A wrapper class for AvroFlumeEvent's with a custom serialization format. + * + * This is necessary because AvroFlumeEvent uses inner data structures + * which are not serializable. + */ +class SparkFlumeEvent() extends Externalizable { + var event : AvroFlumeEvent = new AvroFlumeEvent() + + /* De-serialize from bytes. */ + def readExternal(in: ObjectInput) { + val bodyLength = in.readInt() + val bodyBuff = new Array[Byte](bodyLength) + in.read(bodyBuff) + + val numHeaders = in.readInt() + val headers = new java.util.HashMap[CharSequence, CharSequence] + + for (i <- 0 until numHeaders) { + val keyLength = in.readInt() + val keyBuff = new Array[Byte](keyLength) + in.read(keyBuff) + val key : String = Utils.deserialize(keyBuff) + + val valLength = in.readInt() + val valBuff = new Array[Byte](valLength) + in.read(valBuff) + val value : String = Utils.deserialize(valBuff) + + headers.put(key, value) + } + + event.setBody(ByteBuffer.wrap(bodyBuff)) + event.setHeaders(headers) + } + + /* Serialize to bytes. */ + def writeExternal(out: ObjectOutput) { + val body = event.getBody.array() + out.writeInt(body.length) + out.write(body) + + val numHeaders = event.getHeaders.size() + out.writeInt(numHeaders) + for ((k, v) <- event.getHeaders) { + val keyBuff = Utils.serialize(k.toString) + out.writeInt(keyBuff.length) + out.write(keyBuff) + val valBuff = Utils.serialize(v.toString) + out.writeInt(valBuff.length) + out.write(valBuff) + } + } +} + +object SparkFlumeEvent { + def fromAvroFlumeEvent(in : AvroFlumeEvent) : SparkFlumeEvent = { + val event = new SparkFlumeEvent + event.event = in + event + } +} + +/** A simple server that implements Flume's Avro protocol. 
*/ +class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { + override def append(event : AvroFlumeEvent) : Status = { + receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event) + Status.OK + } + + override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { + events.foreach (event => + receiver.dataHandler += SparkFlumeEvent.fromAvroFlumeEvent(event)) + Status.OK + } +} + +/** A NetworkReceiver which listens for events using the + * Flume Avro interface.*/ +class FlumeReceiver( + streamId: Int, + host: String, + port: Int, + storageLevel: StorageLevel + ) extends NetworkReceiver[SparkFlumeEvent](streamId) { + + lazy val dataHandler = new DataHandler(this, storageLevel) + + protected override def onStart() { + val responder = new SpecificResponder( + classOf[AvroSourceProtocol], new FlumeEventServer(this)); + val server = new NettyServer(responder, new InetSocketAddress(host, port)); + dataHandler.start() + server.start() + logInfo("Flume receiver started") + } + + protected override def onStop() { + dataHandler.stop() + logInfo("Flume receiver stopped") + } + + override def getLocationPreference = Some(host) +} \ No newline at end of file diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala index f3f4c3ab138867233de1ca255619fbc07d254100..4e4e9fc9421057670a4097604aace0aad2e29d27 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputDStream.scala @@ -4,6 +4,7 @@ import scala.collection.mutable.ArrayBuffer import spark.{Logging, SparkEnv, RDD} import spark.rdd.BlockRDD +import spark.streaming.util.{RecurringTimer, SystemClock} import spark.storage.StorageLevel import java.nio.ByteBuffer @@ -41,10 +42,10 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming sealed trait NetworkReceiverMessage case class StopReceiver(msg: String) extends NetworkReceiverMessage -case class ReportBlock(blockId: String) extends NetworkReceiverMessage +case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage case class ReportError(msg: String) extends NetworkReceiverMessage -abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializable with Logging { +abstract class NetworkReceiver[T: ClassManifest](val streamId: Int) extends Serializable with Logging { initLogging() @@ -61,6 +62,9 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ /** This method will be called to stop receiving data. */ protected def onStop() + /** This method conveys a placement preference (hostname) for this receiver. */ + def getLocationPreference() : Option[String] = None + /** * This method starts the receiver. First is accesses all the lazy members to * materialize them. Then it calls the user-defined onStart() method to start @@ -106,21 +110,23 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ actor ! ReportError(e.toString) } + /** * This method pushes a block (as iterator of values) into the block manager. */ - protected def pushBlock(blockId: String, iterator: Iterator[T], level: StorageLevel) { + def pushBlock(blockId: String, iterator: Iterator[T], metadata: Any, level: StorageLevel) { val buffer = new ArrayBuffer[T] ++ iterator env.blockManager.put(blockId, buffer.asInstanceOf[ArrayBuffer[Any]], level) - actor ! ReportBlock(blockId) + + actor ! 
ReportBlock(blockId, metadata) } /** * This method pushes a block (as bytes) into the block manager. */ - protected def pushBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { env.blockManager.putBytes(blockId, bytes, level) - actor ! ReportBlock(blockId) + actor ! ReportBlock(blockId, metadata) } /** A helper actor that communicates with the NetworkInputTracker */ @@ -138,8 +144,8 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ } override def receive() = { - case ReportBlock(blockId) => - tracker ! AddBlocks(streamId, Array(blockId)) + case ReportBlock(blockId, metadata) => + tracker ! AddBlocks(streamId, Array(blockId), metadata) case ReportError(msg) => tracker ! DeregisterReceiver(streamId, msg) case StopReceiver(msg) => @@ -148,4 +154,3 @@ abstract class NetworkReceiver[T: ClassManifest](streamId: Int) extends Serializ } } } - diff --git a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala index ae6692290e5cef2e2a22a88f4ad2861e30a671a5..b421f795ee39eda4a48b5d49444adccb0d6f5976 100644 --- a/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/spark/streaming/NetworkInputTracker.scala @@ -13,7 +13,7 @@ import akka.dispatch._ trait NetworkInputTrackerMessage case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage -case class AddBlocks(streamId: Int, blockIds: Seq[String]) extends NetworkInputTrackerMessage +case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage @@ -22,7 +22,7 @@ class NetworkInputTracker( @transient networkInputStreams: Array[NetworkInputDStream[_]]) extends Logging { - val networkInputStreamIds = networkInputStreams.map(_.id).toArray + val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) val receiverExecutor = new ReceiverExecutor() val receiverInfo = new HashMap[Int, ActorRef] val receivedBlockIds = new HashMap[Int, Queue[String]] @@ -54,14 +54,14 @@ class NetworkInputTracker( private class NetworkInputTrackerActor extends Actor { def receive = { case RegisterReceiver(streamId, receiverActor) => { - if (!networkInputStreamIds.contains(streamId)) { + if (!networkInputStreamMap.contains(streamId)) { throw new Exception("Register received for unexpected id " + streamId) } receiverInfo += ((streamId, receiverActor)) logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address) sender ! 
true } - case AddBlocks(streamId, blockIds) => { + case AddBlocks(streamId, blockIds, metadata) => { val tmp = receivedBlockIds.synchronized { if (!receivedBlockIds.contains(streamId)) { receivedBlockIds += ((streamId, new Queue[String])) @@ -71,6 +71,7 @@ class NetworkInputTracker( tmp.synchronized { tmp ++= blockIds } + networkInputStreamMap(streamId).addMetadata(metadata) } case DeregisterReceiver(streamId, msg) => { receiverInfo -= streamId @@ -97,7 +98,18 @@ class NetworkInputTracker( def startReceivers() { val receivers = networkInputStreams.map(_.createReceiver()) - val tempRDD = ssc.sc.makeRDD(receivers, receivers.size) + + // Right now, we only honor preferences if all receivers have them + val hasLocationPreferences = receivers.map(_.getLocationPreference().isDefined).reduce(_ && _) + + val tempRDD = + if (hasLocationPreferences) { + val receiversWithPreferences = receivers.map(r => (r, Seq(r.getLocationPreference().get))) + ssc.sc.makeRDD[NetworkReceiver[_]](receiversWithPreferences) + } + else { + ssc.sc.makeRDD(receivers, receivers.size) + } val startReceiver = (iterator: Iterator[NetworkReceiver[_]]) => { if (!iterator.hasNext) { diff --git a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala index 03726bfba6f18bbf9574bca120fd9e2c8a6e8d75..6acaa9aab136a987a0791ab89e7ed75fe1467c5c 100644 --- a/streaming/src/main/scala/spark/streaming/RawInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/RawInputDStream.scala @@ -31,6 +31,8 @@ class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: S var blockPushingThread: Thread = null + override def getLocationPreference = None + def onStart() { // Open a socket to the target address and keep reading from it logInfo("Connecting to " + host + ":" + port) @@ -48,7 +50,7 @@ class RawNetworkReceiver(streamId: Int, host: String, port: Int, storageLevel: S val buffer = queue.take() val blockId = "input-" + streamId + "-" + nextBlockNumber nextBlockNumber += 1 - pushBlock(blockId, buffer, storageLevel) + pushBlock(blockId, buffer, null, storageLevel) } } } diff --git a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala index b566200273df4eb2d0b7191f3b8ef8e0f97b66cb..a9e37c0ff0ee6324efcc251ff0000a71c0afcb9a 100644 --- a/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/SocketInputDStream.scala @@ -32,7 +32,9 @@ class SocketReceiver[T: ClassManifest]( storageLevel: StorageLevel ) extends NetworkReceiver[T](streamId) { - lazy protected val dataHandler = new DataHandler(this) + lazy protected val dataHandler = new DataHandler(this, storageLevel) + + override def getLocationPreference = None + protected def onStart() { logInfo("Connecting to " + host + ":" + port) @@ -50,74 +52,6 @@ class SocketReceiver[T: ClassManifest]( dataHandler.stop() } - /** - * This is a helper object that manages the data received from the socket. It divides - * the object received into small batches of 100s of milliseconds, pushes them as - * blocks into the block manager and reports the block IDs to the network input - * tracker. It starts two threads, one to periodically start a new batch and prepare - * the previous batch of as a block, the other to push the blocks into the block - * manager.
- */ - class DataHandler(receiver: NetworkReceiver[T]) extends Serializable { - case class Block(id: String, iterator: Iterator[T]) - - val clock = new SystemClock() - val blockInterval = 200L - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockStorageLevel = storageLevel - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - - var currentBuffer = new ArrayBuffer[T] - - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } - - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - logInfo("Data handler stopped") - } - - def += (obj: T) { - currentBuffer += obj - } - - def updateCurrentBuffer(time: Long) { - try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val blockId = "input-" + streamId + "- " + (time - blockInterval) - val newBlock = new Block(blockId, newBlockBuffer.toIterator) - blocksForPushing.add(newBlock) - } - } catch { - case ie: InterruptedException => - logInfo("Block interval timer thread interrupted") - case e: Exception => - receiver.stop() - } - } - - def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - pushBlock(block.id, block.iterator, storageLevel) - } - } catch { - case ie: InterruptedException => - logInfo("Block pushing thread interrupted") - case e: Exception => - receiver.stop() - } - } - } } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 9c19f6588df6ac606461a777412d700585f3b7ad..ce47bcb2da34b77b4c0c3c54a5475cb012e74701 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -15,6 +15,7 @@ import org.apache.hadoop.io.LongWritable import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.flume.source.avro.AvroFlumeEvent import org.apache.hadoop.fs.Path import java.util.UUID import spark.util.MetadataCleaner @@ -122,6 +123,31 @@ class StreamingContext private ( private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() + /** + * Create an input stream that pulls messages from a Kafka Broker. + * + * @param hostname ZooKeeper hostname. + * @param port ZooKeeper port. + * @param groupId The group id for this consumer. + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param initialOffsets Optional initial offsets for each of the partitions to consume. + * By default the value is pulled from ZooKeeper. + * @param storageLevel RDD storage level. Defaults to MEMORY_ONLY_SER_2.
+ */ + def kafkaStream[T: ClassManifest]( + hostname: String, + port: Int, + groupId: String, + topics: Map[String, Int], + initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](), + storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2 + ): DStream[T] = { + val inputStream = new KafkaInputDStream[T](this, hostname, port, groupId, topics, initialOffsets, storageLevel) + graph.addInputStream(inputStream) + inputStream + } + def networkTextStream( hostname: String, port: Int, @@ -141,6 +167,16 @@ class StreamingContext private ( inputStream } + def flumeStream ( + hostname: String, + port: Int, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2): DStream[SparkFlumeEvent] = { + val inputStream = new FlumeInputDStream(this, hostname, port, storageLevel) + graph.addInputStream(inputStream) + inputStream + } + + def rawNetworkStream[T: ClassManifest]( + hostname: String, port: Int, diff --git a/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala b/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala new file mode 100644 index 0000000000000000000000000000000000000000..e60ce483a3f9d6176fa043b2ff2181476c9de11e --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/FlumeEventCount.scala @@ -0,0 +1,43 @@ +package spark.streaming.examples + +import spark.util.IntParam +import spark.storage.StorageLevel +import spark.streaming._ + +/** + * Produce a streaming count of events received from Flume. + * + * This should be used in conjunction with an AvroSink in Flume. It will start + * an Avro server at the requested host:port address and listen for requests. + * Your Flume AvroSink should be pointed to this address. + * + * Usage: FlumeEventCount <master> <host> <port> + * + * <master> is a Spark master URL + * <host> is the host the Flume receiver will be started on - a receiver + * creates a server and listens for Flume events. + * <port> is the port the Flume receiver will listen on. + */ +object FlumeEventCount { + def main(args: Array[String]) { + if (args.length != 3) { + System.err.println( + "Usage: FlumeEventCount <master> <host> <port>") + System.exit(1) + } + + val Array(master, host, IntParam(port)) = args + + val batchInterval = Milliseconds(2000) + // Create the context and set the batch size + val ssc = new StreamingContext(master, "FlumeEventCount", batchInterval) + + // Create a Flume stream + val stream = ssc.flumeStream(host,port,StorageLevel.MEMORY_ONLY) + + // Print out the count of events received from this server in each batch + stream.count().map(cnt => "Received " + cnt + " flume events."
).print() + + ssc.start() + } +} diff --git a/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala new file mode 100644 index 0000000000000000000000000000000000000000..fe55db6e2c6314739cf3c3e2fd4f34511aff7b6f --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/examples/KafkaWordCount.scala @@ -0,0 +1,69 @@ +package spark.streaming.examples + +import java.util.Properties +import kafka.message.Message +import kafka.producer.SyncProducerConfig +import kafka.producer._ +import spark.SparkContext +import spark.streaming._ +import spark.streaming.StreamingContext._ +import spark.storage.StorageLevel +import spark.streaming.util.RawTextHelper._ + +object KafkaWordCount { + def main(args: Array[String]) { + + if (args.length < 6) { + System.err.println("Usage: KafkaWordCount <master> <hostname> <port> <group> <topics> <numThreads>") + System.exit(1) + } + + val Array(master, hostname, port, group, topics, numThreads) = args + + val sc = new SparkContext(master, "KafkaWordCount") + val ssc = new StreamingContext(sc, Seconds(2)) + ssc.checkpoint("checkpoint") + + val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap + val lines = ssc.kafkaStream[String](hostname, port.toInt, group, topicpMap) + val words = lines.flatMap(_.split(" ")) + val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) + wordCounts.print() + + ssc.start() + } +} + +// Produces some random words between 1 and 100. +object KafkaWordCountProducer { + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: KafkaWordCountProducer <hostname> <port> <topic> <messagesPerSec> <wordsPerMessage>") + System.exit(1) + } + + val Array(hostname, port, topic, messagesPerSec, wordsPerMessage) = args + + // Zookeper connection properties + val props = new Properties() + props.put("zk.connect", hostname + ":" + port) + props.put("serializer.class", "kafka.serializer.StringEncoder") + + val config = new ProducerConfig(props) + val producer = new Producer[String, String](config) + + // Send some messages + while(true) { + val messages = (1 to messagesPerSec.toInt).map { messageNum => + (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ") + }.toArray + println(messages.mkString(",")) + val data = new ProducerData[String, String](topic, messages) + producer.send(data) + Thread.sleep(100) + } + } + +} + diff --git a/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala new file mode 100644 index 0000000000000000000000000000000000000000..7c642d4802d892d656b0291effb6c9ac58e9b0ee --- /dev/null +++ b/streaming/src/main/scala/spark/streaming/input/KafkaInputDStream.scala @@ -0,0 +1,193 @@ +package spark.streaming + +import java.util.Properties +import java.util.concurrent.Executors +import kafka.consumer._ +import kafka.message.{Message, MessageSet, MessageAndMetadata} +import kafka.serializer.StringDecoder +import kafka.utils.{Utils, ZKGroupTopicDirs} +import kafka.utils.ZkUtils._ +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ +import spark._ +import spark.RDD +import spark.storage.StorageLevel + +// Key for a specific Kafka Partition: (broker, topic, group, part) +case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int) +// NOT USED - Originally intended for 
fault-tolerance +// Metadata for a Kafka Stream that is sent to the Master +case class KafkaInputDStreamMetadata(timestamp: Long, data: Map[KafkaPartitionKey, Long]) +// NOT USED - Originally intended for fault-tolerance +// Checkpoint data specific to a KafkaInputDStream +case class KafkaDStreamCheckpointData(kafkaRdds: HashMap[Time, Any], + savedOffsets: Map[KafkaPartitionKey, Long]) extends DStreamCheckpointData(kafkaRdds) + +/** + * Input stream that pulls messages from a Kafka Broker. + * + * @param host ZooKeeper hostname. + * @param port ZooKeeper port. + * @param groupId The group id for this consumer. + * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed + * in its own thread. + * @param initialOffsets Optional initial offsets for each of the partitions to consume. + * By default the value is pulled from ZooKeeper. + * @param storageLevel RDD storage level. + */ +class KafkaInputDStream[T: ClassManifest]( + @transient ssc_ : StreamingContext, + host: String, + port: Int, + groupId: String, + topics: Map[String, Int], + initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel + ) extends NetworkInputDStream[T](ssc_ ) with Logging { + + // Metadata that keeps track of which messages have already been consumed. + var savedOffsets = HashMap[Long, Map[KafkaPartitionKey, Long]]() + + /* NOT USED - Originally intended for fault-tolerance + + // In case of a failure, the offsets for a particular timestamp will be restored. + @transient var restoredOffsets : Map[KafkaPartitionKey, Long] = null + + + override protected[streaming] def addMetadata(metadata: Any) { + metadata match { + case x : KafkaInputDStreamMetadata => + savedOffsets(x.timestamp) = x.data + // TODO: Remove logging + logInfo("New saved Offsets: " + savedOffsets) + case _ => logInfo("Received unknown metadata: " + metadata.toString) + } + } + + override protected[streaming] def updateCheckpointData(currentTime: Time) { + super.updateCheckpointData(currentTime) + if(savedOffsets.size > 0) { + // Find the offsets that were stored before the checkpoint was initiated + val key = savedOffsets.keys.toList.sortWith(_ < _).filter(_ < currentTime.millis).last + val latestOffsets = savedOffsets(key) + logInfo("Updating KafkaDStream checkpoint data: " + latestOffsets.toString) + checkpointData = KafkaDStreamCheckpointData(checkpointData.rdds, latestOffsets) + // TODO: This may throw out offsets that are created after the checkpoint, + // but it's unlikely we'll need them. + savedOffsets.clear() + } + } + + override protected[streaming] def restoreCheckpointData() { + super.restoreCheckpointData() + logInfo("Restoring KafkaDStream checkpoint data.") + checkpointData match { + case x : KafkaDStreamCheckpointData => + restoredOffsets = x.savedOffsets + logInfo("Restored KafkaDStream offsets: " + savedOffsets) + } + } */ + + def createReceiver(): NetworkReceiver[T] = { + new KafkaReceiver(id, host, port, groupId, topics, initialOffsets, storageLevel) + .asInstanceOf[NetworkReceiver[T]] + } +} + +class KafkaReceiver(streamId: Int, host: String, port: Int, groupId: String, + topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long], + storageLevel: StorageLevel) extends NetworkReceiver[Any](streamId) { + + // Timeout for establishing a connection to ZooKeeper in ms. + val ZK_TIMEOUT = 10000 + + // Handles pushing data into the BlockManager + lazy protected val dataHandler = new DataHandler(this, storageLevel) + // Keeps track of the current offsets.
Maps from (broker, topic, group, part) -> Offset + lazy val offsets = HashMap[KafkaPartitionKey, Long]() + // Connection to Kafka + var consumerConnector : ZookeeperConsumerConnector = null + + def onStop() { + dataHandler.stop() + } + + def onStart() { + + // Starting the DataHandler that buffers blocks and pushes them into the BlockManager + dataHandler.start() + + // In case we are using multiple Threads to handle Kafka Messages + val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) + + val zooKeeperEndPoint = host + ":" + port + logInfo("Starting Kafka Consumer Stream with group: " + groupId) + logInfo("Initial offsets: " + initialOffsets.toString) + + // ZooKeeper connection properties + val props = new Properties() + props.put("zk.connect", zooKeeperEndPoint) + props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString) + props.put("groupid", groupId) + + // Create the connection to the cluster + logInfo("Connecting to ZooKeeper: " + zooKeeperEndPoint) + val consumerConfig = new ConsumerConfig(props) + consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector] + logInfo("Connected to " + zooKeeperEndPoint) + + // Reset the Kafka offsets in case we are recovering from a failure + resetOffsets(initialOffsets) + + // Create Threads for each Topic/Message Stream we are listening to + val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder()) + + // Start the message handler for each partition + topicMessageStreams.values.foreach { streams => + streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) } + } + + } + + // Overwrites the offsets in ZooKeeper. + private def resetOffsets(offsets: Map[KafkaPartitionKey, Long]) { + offsets.foreach { case(key, offset) => + val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic) + val partitionName = key.brokerId + "-" + key.partId + updatePersistentPath(consumerConnector.zkClient, + topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString) + } + } + + // Handles Kafka Messages + private class MessageHandler(stream: KafkaStream[String]) extends Runnable { + def run() { + logInfo("Starting MessageHandler.") + stream.takeWhile { msgAndMetadata => + dataHandler += msgAndMetadata.message + + // Updating the offset. The key is (broker, topic, group, partition).
+ val key = KafkaPartitionKey(msgAndMetadata.topicInfo.brokerId, msgAndMetadata.topic, + groupId, msgAndMetadata.topicInfo.partition.partId) + val offset = msgAndMetadata.topicInfo.getConsumeOffset + offsets.put(key, offset) + // logInfo("Handled message: " + (key, offset).toString) + + // Keep on handling messages + true + } + } + } + + // NOT USED - Originally intended for fault-tolerance + // class KafkaDataHandler(receiver: KafkaReceiver, storageLevel: StorageLevel) + // extends DataHandler[Any](receiver, storageLevel) { + + // override def createBlock(blockId: String, iterator: Iterator[Any]) : Block = { + // // Creates a new Block with Kafka-specific Metadata + // new Block(blockId, iterator, KafkaInputDStreamMetadata(System.currentTimeMillis, offsets.toMap)) + // } + + // } + +} diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index b3afedf39fe45ba34f1c1c7ed567d3a0add7755b..0d82b2f1eaf4803ee49a7e278636d23d01b74eda 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -63,9 +63,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // then check whether some RDD has been checkpointed or not ssc.start() runStreamsWithRealDelay(ssc, firstNumBatches) - logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.mkString(",\n") + "]") - assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream before first failure") - stateStream.checkpointData.foreach { + logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.rdds.mkString(",\n") + "]") + assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before first failure") + stateStream.checkpointData.rdds.foreach { case (time, data) => { val file = new File(data.toString) assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") @@ -74,7 +74,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // Run till a further time such that previous checkpoint files in the stream would be deleted // and check whether the earlier checkpoint files are deleted - val checkpointFiles = stateStream.checkpointData.map(x => new File(x._2.toString)) + val checkpointFiles = stateStream.checkpointData.rdds.map(x => new File(x._2.toString)) runStreamsWithRealDelay(ssc, secondNumBatches) checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) ssc.stop() @@ -91,8 +91,8 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { // is present in the checkpoint data or not ssc.start() runStreamsWithRealDelay(ssc, 1) - assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream before second failure") - stateStream.checkpointData.foreach { + assert(!stateStream.checkpointData.rdds.isEmpty, "No checkpointed RDDs in state stream before second failure") + stateStream.checkpointData.rdds.foreach { case (time, data) => { val file = new File(data.toString) assert(file.exists(), diff --git a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala index e98c0967250e18877c33e6317206b497dc8ff461..ed9a6590923f103d9438062bbba12c284168bdd3 100644 --- a/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala +++ 
b/streaming/src/test/scala/spark/streaming/InputStreamsSuite.scala @@ -1,6 +1,6 @@ package spark.streaming -import java.net.{SocketException, Socket, ServerSocket} +import java.net.{InetSocketAddress, SocketException, Socket, ServerSocket} import java.io.{File, BufferedWriter, OutputStreamWriter} import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} import collection.mutable.{SynchronizedBuffer, ArrayBuffer} @@ -10,7 +10,14 @@ import spark.Logging import scala.util.Random import org.apache.commons.io.FileUtils import org.scalatest.BeforeAndAfter - +import org.apache.flume.source.avro.AvroSourceProtocol +import org.apache.flume.source.avro.AvroFlumeEvent +import org.apache.flume.source.avro.Status +import org.apache.avro.ipc.{specific, NettyTransceiver} +import org.apache.avro.ipc.specific.SpecificRequestor +import java.nio.ByteBuffer +import collection.JavaConversions._ +import java.nio.charset.Charset class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { @@ -123,6 +130,54 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { ssc.stop() } + test("flume input stream") { + // Set up the streaming context and input streams + val ssc = new StreamingContext(master, framework, batchDuration) + val flumeStream = ssc.flumeStream("localhost", 33333, StorageLevel.MEMORY_AND_DISK) + val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] + with SynchronizedBuffer[Seq[SparkFlumeEvent]] + val outputStream = new TestOutputStream(flumeStream, outputBuffer) + ssc.registerOutputStream(outputStream) + ssc.start() + + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = Seq(1, 2, 3, 4, 5) + + val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", 33333)); + val client = SpecificRequestor.getClient( + classOf[AvroSourceProtocol], transceiver); + + for (i <- 0 until input.size) { + val event = new AvroFlumeEvent + event.setBody(ByteBuffer.wrap(input(i).toString.getBytes())) + event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) + client.append(event) + Thread.sleep(500) + clock.addToTime(batchDuration.milliseconds) + } + + val startTime = System.currentTimeMillis() + while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { + logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size) + Thread.sleep(100) + } + Thread.sleep(1000) + val timeTaken = System.currentTimeMillis() - startTime + assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") + logInfo("Stopping context") + ssc.stop() + + val decoder = Charset.forName("UTF-8").newDecoder() + + assert(outputBuffer.size === input.length) + for (i <- 0 until outputBuffer.size) { + assert(outputBuffer(i).size === 1) + val str = decoder.decode(outputBuffer(i).head.event.getBody) + assert(str.toString === input(i).toString) + assert(outputBuffer(i).head.event.getHeaders.get("test") === "header") + } + } + test("file input stream") { // Create a temporary directory