diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index 78097502bca48d207a65563718079c638490d987..a6322dc58dc78b1ad28201fdbce27bd8bfb2b0a3 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -16,6 +16,7 @@ import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.util.ReflectionUtils import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext} +import spark.util.NextIterator /** @@ -62,7 +63,7 @@ class HadoopRDD[K, V]( .asInstanceOf[InputFormat[K, V]] } - override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] { + override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] { val split = theSplit.asInstanceOf[HadoopPartition] var reader: RecordReader[K, V] = null @@ -71,38 +72,22 @@ class HadoopRDD[K, V]( reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback{ () => close() } + context.addOnCompleteCallback{ () => closeIfNeeded() } val key: K = reader.createKey() val value: V = reader.createValue() - var gotNext = false - var finished = false - - override def hasNext: Boolean = { - if (!gotNext) { - try { - finished = !reader.next(key, value) - } catch { - case eof: EOFException => - finished = true - } - gotNext = true - } - !finished - } - override def next: (K, V) = { - if (!gotNext) { + override def getNext() = { + try { finished = !reader.next(key, value) + } catch { + case eof: EOFException => + finished = true } - if (finished) { - throw new NoSuchElementException("End of stream") - } - gotNext = false (key, value) } - private def close() { + override def close() { try { reader.close() } catch { diff --git a/core/src/main/scala/spark/serializer/Serializer.scala b/core/src/main/scala/spark/serializer/Serializer.scala index 50b086125a55e2217f5877c54c708fae8f4e8a3b..aca86ab6f0f1d93390bab2eeb2a197d2ca03a856 100644 --- a/core/src/main/scala/spark/serializer/Serializer.scala +++ b/core/src/main/scala/spark/serializer/Serializer.scala @@ -72,40 +72,18 @@ trait DeserializationStream { * Read the elements of this stream through an iterator. This can only be called once, as * reading each element will consume data from the input source. */ - def asIterator: Iterator[Any] = new Iterator[Any] { - var gotNext = false - var finished = false - var nextValue: Any = null - - private def getNext() { + def asIterator: Iterator[Any] = new spark.util.NextIterator[Any] { + override protected def getNext() = { try { - nextValue = readObject[Any]() + readObject[Any]() } catch { case eof: EOFException => finished = true } - gotNext = true } - override def hasNext: Boolean = { - if (!gotNext) { - getNext() - } - if (finished) { - close() - } - !finished - } - - override def next(): Any = { - if (!gotNext) { - getNext() - } - if (finished) { - throw new NoSuchElementException("End of stream") - } - gotNext = false - nextValue + override protected def close() { + DeserializationStream.this.close() } } } diff --git a/core/src/main/scala/spark/util/NextIterator.scala b/core/src/main/scala/spark/util/NextIterator.scala new file mode 100644 index 0000000000000000000000000000000000000000..48b5018dddbdd1326a7e7ff7bda771bc44c76ce0 --- /dev/null +++ b/core/src/main/scala/spark/util/NextIterator.scala @@ -0,0 +1,71 @@ +package spark.util + +/** Provides a basic/boilerplate Iterator implementation. */ +private[spark] abstract class NextIterator[U] extends Iterator[U] { + + private var gotNext = false + private var nextValue: U = _ + private var closed = false + protected var finished = false + + /** + * Method for subclasses to implement to provide the next element. + * + * If no next element is available, the subclass should set `finished` + * to `true` and may return any value (it will be ignored). + * + * This convention is required because `null` may be a valid value, + * and using `Option` seems like it might create unnecessary Some/None + * instances, given some iterators might be called in a tight loop. + * + * @return U, or set 'finished' when done + */ + protected def getNext(): U + + /** + * Method for subclasses to implement when all elements have been successfully + * iterated, and the iteration is done. + * + * <b>Note:</b> `NextIterator` cannot guarantee that `close` will be + * called because it has no control over what happens when an exception + * happens in the user code that is calling hasNext/next. + * + * Ideally you should have another try/catch, as in HadoopRDD, that + * ensures any resources are closed should iteration fail. + */ + protected def close() + + /** + * Calls the subclass-defined close method, but only once. + * + * Usually calling `close` multiple times should be fine, but historically + * there have been issues with some InputFormats throwing exceptions. + */ + def closeIfNeeded() { + if (!closed) { + close() + closed = true + } + } + + override def hasNext: Boolean = { + if (!finished) { + if (!gotNext) { + nextValue = getNext() + if (finished) { + closeIfNeeded() + } + gotNext = true + } + } + !finished + } + + override def next(): U = { + if (!hasNext) { + throw new NoSuchElementException("End of stream") + } + gotNext = false + nextValue + } +} \ No newline at end of file diff --git a/core/src/test/scala/spark/util/NextIteratorSuite.scala b/core/src/test/scala/spark/util/NextIteratorSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..ed5b36da73fd16dcb8844bfcc68e5728f7406355 --- /dev/null +++ b/core/src/test/scala/spark/util/NextIteratorSuite.scala @@ -0,0 +1,68 @@ +package spark.util + +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import scala.collection.mutable.Buffer +import java.util.NoSuchElementException + +class NextIteratorSuite extends FunSuite with ShouldMatchers { + test("one iteration") { + val i = new StubIterator(Buffer(1)) + i.hasNext should be === true + i.next should be === 1 + i.hasNext should be === false + intercept[NoSuchElementException] { i.next() } + } + + test("two iterations") { + val i = new StubIterator(Buffer(1, 2)) + i.hasNext should be === true + i.next should be === 1 + i.hasNext should be === true + i.next should be === 2 + i.hasNext should be === false + intercept[NoSuchElementException] { i.next() } + } + + test("empty iteration") { + val i = new StubIterator(Buffer()) + i.hasNext should be === false + intercept[NoSuchElementException] { i.next() } + } + + test("close is called once for empty iterations") { + val i = new StubIterator(Buffer()) + i.hasNext should be === false + i.hasNext should be === false + i.closeCalled should be === 1 + } + + test("close is called once for non-empty iterations") { + val i = new StubIterator(Buffer(1, 2)) + i.next should be === 1 + i.next should be === 2 + // close isn't called until we check for the next element + i.closeCalled should be === 0 + i.hasNext should be === false + i.closeCalled should be === 1 + i.hasNext should be === false + i.closeCalled should be === 1 + } + + class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] { + var closeCalled = 0 + + override def getNext() = { + if (ints.size == 0) { + finished = true + 0 + } else { + ints.remove(0) + } + } + + override def close() { + closeCalled += 1 + } + } +} diff --git a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala index 4af839ad7f03d1e75dd19715a2cb0785a9a91628..1408af0afa5018545fd9ec1db61df9182499d095 100644 --- a/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/spark/streaming/dstream/SocketInputDStream.scala @@ -2,6 +2,7 @@ package spark.streaming.dstream import spark.streaming.StreamingContext import spark.storage.StorageLevel +import spark.util.NextIterator import java.io._ import java.net.Socket @@ -59,45 +60,18 @@ object SocketReceiver { */ def bytesToLines(inputStream: InputStream): Iterator[String] = { val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")) - - val iterator = new Iterator[String] { - var gotNext = false - var finished = false - var nextValue: String = null - - private def getNext() { - try { - nextValue = dataInputStream.readLine() - if (nextValue == null) { - finished = true - } - } - gotNext = true - } - - override def hasNext: Boolean = { - if (!finished) { - if (!gotNext) { - getNext() - if (finished) { - dataInputStream.close() - } - } + new NextIterator[String] { + protected override def getNext() = { + val nextValue = dataInputStream.readLine() + if (nextValue == null) { + finished = true } - !finished + nextValue } - override def next(): String = { - if (finished) { - throw new NoSuchElementException("End of stream") - } - if (!gotNext) { - getNext() - } - gotNext = false - nextValue + protected override def close() { + dataInputStream.close() } } - iterator } }