Commit 4032beba authored by Matei Zaharia

Merge pull request #521 from stephenh/earlyclose

Close the reader in HadoopRDD as soon as iteration ends.
parents 3c97276a e7f1a69c
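
The diff below replaces three hand-rolled anonymous Iterator implementations (in HadoopRDD, DeserializationStream.asIterator, and SocketReceiver.bytesToLines) with a shared spark.util.NextIterator base class. NextIterator centralizes the gotNext/finished bookkeeping and calls close() as soon as the final element has been read, instead of leaving the resource open until task completion. A minimal sketch of the resulting pattern (LineIterator and the file path are hypothetical, for illustration only):

import java.io.{BufferedReader, FileReader}
import spark.util.NextIterator

// Hypothetical subclass: yields the lines of a file and closes the reader
// as soon as readLine() signals end of input.
class LineIterator(path: String) extends NextIterator[String] {
  private val reader = new BufferedReader(new FileReader(path))

  override def getNext(): String = {
    val line = reader.readLine()
    if (line == null) {
      finished = true  // end of iteration; the returned value is ignored
    }
    line
  }

  override def close() {
    reader.close()  // run at most once, via closeIfNeeded()
  }
}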
@@ -16,6 +16,7 @@ import org.apache.hadoop.mapred.Reporter
 import org.apache.hadoop.util.ReflectionUtils
 import spark.{Dependency, Logging, Partition, RDD, SerializableWritable, SparkContext, TaskContext}
+import spark.util.NextIterator
 
 /**
@@ -62,7 +63,7 @@ class HadoopRDD[K, V](
       .asInstanceOf[InputFormat[K, V]]
   }
 
-  override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
+  override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] {
     val split = theSplit.asInstanceOf[HadoopPartition]
     var reader: RecordReader[K, V] = null
@@ -71,38 +72,22 @@
     reader = fmt.getRecordReader(split.inputSplit.value, conf, Reporter.NULL)
 
     // Register an on-task-completion callback to close the input stream.
-    context.addOnCompleteCallback{ () => close() }
+    context.addOnCompleteCallback{ () => closeIfNeeded() }
 
     val key: K = reader.createKey()
     val value: V = reader.createValue()
-    var gotNext = false
-    var finished = false
 
-    override def hasNext: Boolean = {
-      if (!gotNext) {
-        try {
-          finished = !reader.next(key, value)
-        } catch {
-          case eof: EOFException =>
-            finished = true
-        }
-        gotNext = true
-      }
-      !finished
-    }
-
-    override def next: (K, V) = {
-      if (!gotNext) {
-        finished = !reader.next(key, value)
-      }
-      if (finished) {
-        throw new NoSuchElementException("End of stream")
-      }
-      gotNext = false
-      (key, value)
-    }
+    override def getNext() = {
+      try {
+        finished = !reader.next(key, value)
+      } catch {
+        case eof: EOFException =>
+          finished = true
+      }
+      (key, value)
+    }
 
-    private def close() {
+    override def close() {
       try {
         reader.close()
       } catch {
......
@@ -72,40 +72,18 @@ trait DeserializationStream {
    * Read the elements of this stream through an iterator. This can only be called once, as
    * reading each element will consume data from the input source.
    */
-  def asIterator: Iterator[Any] = new Iterator[Any] {
-    var gotNext = false
-    var finished = false
-    var nextValue: Any = null
-
-    private def getNext() {
-      try {
-        nextValue = readObject[Any]()
-      } catch {
-        case eof: EOFException =>
-          finished = true
-      }
-      gotNext = true
-    }
-
-    override def hasNext: Boolean = {
-      if (!gotNext) {
-        getNext()
-      }
-      if (finished) {
-        close()
-      }
-      !finished
-    }
-
-    override def next(): Any = {
-      if (!gotNext) {
-        getNext()
-      }
-      if (finished) {
-        throw new NoSuchElementException("End of stream")
-      }
-      gotNext = false
-      nextValue
-    }
+  def asIterator: Iterator[Any] = new spark.util.NextIterator[Any] {
+    override protected def getNext() = {
+      try {
+        readObject[Any]()
+      } catch {
+        case eof: EOFException =>
+          finished = true
+      }
+    }
+
+    override protected def close() {
+      DeserializationStream.this.close()
+    }
   }
 }
package spark.util

/** Provides a basic/boilerplate Iterator implementation. */
private[spark] abstract class NextIterator[U] extends Iterator[U] {

  private var gotNext = false
  private var nextValue: U = _
  private var closed = false
  protected var finished = false

  /**
   * Method for subclasses to implement to provide the next element.
   *
   * If no next element is available, the subclass should set `finished`
   * to `true` and may return any value (it will be ignored).
   *
   * This convention is required because `null` may be a valid value,
   * and using `Option` seems like it might create unnecessary Some/None
   * instances, given some iterators might be called in a tight loop.
   *
   * @return U, or set 'finished' when done
   */
  protected def getNext(): U

  /**
   * Method for subclasses to implement when all elements have been successfully
   * iterated, and the iteration is done.
   *
   * <b>Note:</b> `NextIterator` cannot guarantee that `close` will be
   * called because it has no control over what happens when an exception
   * happens in the user code that is calling hasNext/next.
   *
   * Ideally you should have another try/catch, as in HadoopRDD, that
   * ensures any resources are closed should iteration fail.
   */
  protected def close()

  /**
   * Calls the subclass-defined close method, but only once.
   *
   * Usually calling `close` multiple times should be fine, but historically
   * there have been issues with some InputFormats throwing exceptions.
   */
  def closeIfNeeded() {
    if (!closed) {
      close()
      closed = true
    }
  }

  override def hasNext: Boolean = {
    if (!finished) {
      if (!gotNext) {
        nextValue = getNext()
        if (finished) {
          closeIfNeeded()
        }
        gotNext = true
      }
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) {
      throw new NoSuchElementException("End of stream")
    }
    gotNext = false
    nextValue
  }
}
\ No newline at end of file
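
Worth spelling out about the contract above: close() fires inside the hasNext call whose getNext() sets finished, not when the last element is returned by next(). A small hypothetical illustration of that timing:

import spark.util.NextIterator
import scala.collection.mutable.Queue

// Hypothetical iterator over a queue that reports when close() runs.
val iter = new NextIterator[Int] {
  private val values = Queue(1, 2, 3)

  override def getNext(): Int = {
    if (values.isEmpty) {
      finished = true
      0  // ignored because finished is set
    } else {
      values.dequeue()
    }
  }

  override def close() {
    println("closed")
  }
}

iter.foreach(println)  // prints 1, 2, 3, then "closed" from the final hasNext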
package spark.util

import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import scala.collection.mutable.Buffer
import java.util.NoSuchElementException

class NextIteratorSuite extends FunSuite with ShouldMatchers {
  test("one iteration") {
    val i = new StubIterator(Buffer(1))
    i.hasNext should be === true
    i.next should be === 1
    i.hasNext should be === false
    intercept[NoSuchElementException] { i.next() }
  }

  test("two iterations") {
    val i = new StubIterator(Buffer(1, 2))
    i.hasNext should be === true
    i.next should be === 1
    i.hasNext should be === true
    i.next should be === 2
    i.hasNext should be === false
    intercept[NoSuchElementException] { i.next() }
  }

  test("empty iteration") {
    val i = new StubIterator(Buffer())
    i.hasNext should be === false
    intercept[NoSuchElementException] { i.next() }
  }

  test("close is called once for empty iterations") {
    val i = new StubIterator(Buffer())
    i.hasNext should be === false
    i.hasNext should be === false
    i.closeCalled should be === 1
  }

  test("close is called once for non-empty iterations") {
    val i = new StubIterator(Buffer(1, 2))
    i.next should be === 1
    i.next should be === 2
    // close isn't called until we check for the next element
    i.closeCalled should be === 0
    i.hasNext should be === false
    i.closeCalled should be === 1
    i.hasNext should be === false
    i.closeCalled should be === 1
  }

  class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] {
    var closeCalled = 0

    override def getNext() = {
      if (ints.size == 0) {
        finished = true
        0
      } else {
        ints.remove(0)
      }
    }

    override def close() {
      closeCalled += 1
    }
  }
}
@@ -2,6 +2,7 @@ package spark.streaming.dstream
 
 import spark.streaming.StreamingContext
 import spark.storage.StorageLevel
+import spark.util.NextIterator
 
 import java.io._
 import java.net.Socket
@@ -59,45 +60,18 @@
    */
   def bytesToLines(inputStream: InputStream): Iterator[String] = {
     val dataInputStream = new BufferedReader(new InputStreamReader(inputStream, "UTF-8"))
-
-    val iterator = new Iterator[String] {
-      var gotNext = false
-      var finished = false
-      var nextValue: String = null
-
-      private def getNext() {
-        try {
-          nextValue = dataInputStream.readLine()
-          if (nextValue == null) {
-            finished = true
-          }
-        }
-        gotNext = true
-      }
-
-      override def hasNext: Boolean = {
-        if (!finished) {
-          if (!gotNext) {
-            getNext()
-            if (finished) {
-              dataInputStream.close()
-            }
-          }
-        }
-        !finished
-      }
-
-      override def next(): String = {
-        if (finished) {
-          throw new NoSuchElementException("End of stream")
-        }
-        if (!gotNext) {
-          getNext()
-        }
-        gotNext = false
-        nextValue
-      }
-    }
-    iterator
+    new NextIterator[String] {
+      protected override def getNext() = {
+        val nextValue = dataInputStream.readLine()
+        if (nextValue == null) {
+          finished = true
+        }
+        nextValue
+      }
+
+      protected override def close() {
+        dataInputStream.close()
+      }
+    }
   }
 }
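
After this change, bytesToLines hands back a NextIterator that closes the underlying BufferedReader as soon as readLine() returns null, i.e. when the remote end closes the connection. A hypothetical usage sketch, assuming SocketReceiver is visible from the calling code and using placeholder host/port values:

import java.net.Socket
import spark.streaming.dstream.SocketReceiver

val socket = new Socket("localhost", 9999)  // placeholder host and port
val lines = SocketReceiver.bytesToLines(socket.getInputStream)
lines.foreach(println)  // the reader is closed once the stream is exhausted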