diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index cd5b7d57f32f58c3c1a5e20e1dcc011d0650a7e1..d1451bc2124c581eff01dfae5277612ea5c995c7 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -198,7 +198,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
     outbox.synchronized {
       outbox.addMessage(message)
       if (channel.isConnected) {
-        changeConnectionKeyInterest(SelectionKey.OP_WRITE)
+        changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
       }
     }
   }
@@ -219,7 +219,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
   def finishConnect() {
     try {
       channel.finishConnect
-      changeConnectionKeyInterest(SelectionKey.OP_WRITE)
+      changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
       logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
     } catch {
       case e: Exception => {
@@ -239,8 +239,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
             currentBuffers ++= chunk.buffers
           }
           case None => {
-            changeConnectionKeyInterest(0)
-            /*key.interestOps(0)*/
+            changeConnectionKeyInterest(SelectionKey.OP_READ)
             return
           }
         }
@@ -267,6 +266,23 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
       }
     }
   }
+
+  override def read() {
+    // We don't expect the other side to send anything; so, we just read to detect an error or EOF.
+    try {
+      val length = channel.read(ByteBuffer.allocate(1))
+      if (length == -1) { // EOF
+        close()
+      } else if (length > 0) {
+        logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId)
+      }
+    } catch {
+      case e: Exception =>
+        logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e)
+        callOnExceptionCallback(e)
+        close()
+    }
+  }
 }
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
index 0e2585daa434cdad2c8deb14bd088072a69dfe31..caa4ba3a3705af5587099951c1914a03663b5ea8 100644
--- a/core/src/test/scala/spark/DistributedSuite.scala
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -217,6 +217,27 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
       assert(grouped.collect.size === 1)
     }
   }
+
+  test("recover from node failures with replication") {
+    import DistributedSuite.{markNodeIfIdentity, failOnMarkedIdentity}
+    DistributedSuite.amMaster = true
+    // Using more than two nodes so we don't have a symmetric communication pattern and might
+    // cache a partially correct list of peers.
+    sc = new SparkContext("local-cluster[3,1,512]", "test")
+    for (i <- 1 to 3) {
+      val data = sc.parallelize(Seq(true, false, false, false), 4)
+      data.persist(StorageLevel.MEMORY_ONLY_2)
+
+      assert(data.count === 4)
+      assert(data.map(markNodeIfIdentity).collect.size === 4)
+      assert(data.map(failOnMarkedIdentity).collect.size === 4)
+
+      // Create a new replicated RDD to make sure that cached peer information doesn't cause
+      // problems.
+      val data2 = sc.parallelize(Seq(true, true), 2).persist(StorageLevel.MEMORY_ONLY_2)
+      assert(data2.count === 2)
+    }
+  }
 }
 
 object DistributedSuite {
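
For reference, a minimal standalone sketch (not part of the patch) of the java.nio idiom the Connection.scala hunks rely on: a sending-side channel keeps SelectionKey.OP_READ registered even though no inbound data is expected, because a read() that returns -1 is the only way the selector loop learns that the peer has closed the socket. Everything below (the EofDetectionSketch object, host, and port) is hypothetical illustration, not Spark code.

import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.{SelectionKey, Selector, SocketChannel}

object EofDetectionSketch {
  def main(args: Array[String]) {
    val selector = Selector.open()
    val channel = SocketChannel.open()
    channel.configureBlocking(false)
    channel.connect(new InetSocketAddress("localhost", 9999)) // placeholder peer
    channel.register(selector, SelectionKey.OP_CONNECT)

    while (true) {
      selector.select()
      val keys = selector.selectedKeys().iterator()
      while (keys.hasNext) {
        val key = keys.next()
        keys.remove()
        val ch = key.channel().asInstanceOf[SocketChannel]
        if (key.isConnectable && ch.finishConnect()) {
          // Register OP_READ alongside OP_WRITE even though we never expect data:
          // it is what lets select() wake up when the peer closes the socket.
          key.interestOps(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
        }
        if (key.isValid && key.isReadable) {
          // A read of -1 signals EOF: the remote side has closed the connection.
          if (ch.read(ByteBuffer.allocate(1)) == -1) {
            ch.close()
            return
          }
        }
        if (key.isValid && key.isWritable) {
          // ... write any pending outbound buffers here; once drained, fall back
          // to OP_READ (rather than clearing all interests, as the old code did
          // with interestOps(0)) so EOF detection keeps working while idle.
          key.interestOps(SelectionKey.OP_READ)
        }
      }
    }
  }
}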