From 16a0789e10d2ac714e7c623b026c4a58ca9678d6 Mon Sep 17 00:00:00 2001
From: Charles Reiss <charles@eecs.berkeley.edu>
Date: Tue, 29 Jan 2013 17:09:53 -0800
Subject: [PATCH] Remember ConnectionManagerId used to initiate
 SendingConnections.

This prevents ConnectionManager from getting confused if a machine
has multiple host names and the one getHostName() finds happens
not to be the one that was passed from, e.g., the BlockManagerMaster.
---
 .../src/main/scala/spark/network/Connection.scala | 15 +++++++++++----
 .../scala/spark/network/ConnectionManager.scala   |  3 ++-
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala
index c193bf7c8d..cd5b7d57f3 100644
--- a/core/src/main/scala/spark/network/Connection.scala
+++ b/core/src/main/scala/spark/network/Connection.scala
@@ -12,7 +12,14 @@ import java.net._
 
 
 private[spark]
-abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging {
+abstract class Connection(val channel: SocketChannel, val selector: Selector,
+                          val remoteConnectionManagerId: ConnectionManagerId) extends Logging {
+  def this(channel_ : SocketChannel, selector_ : Selector) = {
+    this(channel_, selector_,
+         ConnectionManagerId.fromSocketAddress(
+            channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
+         ))
+  }
 
   channel.configureBlocking(false)
   channel.socket.setTcpNoDelay(true)
@@ -25,7 +32,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
   var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
 
   val remoteAddress = getRemoteAddress()
-  val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
 
   def key() = channel.keyFor(selector)
 
@@ -103,8 +109,9 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
 }
 
 
-private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector) 
-extends Connection(SocketChannel.open, selector_) {
+private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
+                                       remoteId_ : ConnectionManagerId)
+extends Connection(SocketChannel.open, selector_, remoteId_) {
 
   class Outbox(fair: Int = 0) {
     val messages = new Queue[Message]()
diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala
index 2ecd14f536..c7f226044d 100644
--- a/core/src/main/scala/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/spark/network/ConnectionManager.scala
@@ -299,7 +299,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
   private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
     def startNewConnection(): SendingConnection = {
       val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
-      val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId, new SendingConnection(inetSocketAddress, selector))
+      val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId,
+          new SendingConnection(inetSocketAddress, selector, connectionManagerId))
       newConnection   
     }
     val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
-- 
GitLab