Skip to content
Snippets Groups Projects
Commit 7cfa4c6b authored by Marcelo Vanzin
Browse files

[SPARK-11865][NETWORK] Avoid returning inactive client in TransportClientFactory.

There's a very narrow race here where it would be possible for the timeout handler
to close a channel after the client factory verified that the channel was still
active. This change makes sure the client is marked as being recently in use so
that the timeout handler does not close it until a new timeout cycle elapses.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #9853 from vanzin/SPARK-11865.
parent 242be7da
No related branches found
No related tags found
No related merge requests found
...@@ -73,10 +73,12 @@ public class TransportClient implements Closeable { ...@@ -73,10 +73,12 @@ public class TransportClient implements Closeable {
private final Channel channel; private final Channel channel;
private final TransportResponseHandler handler; private final TransportResponseHandler handler;
@Nullable private String clientId; @Nullable private String clientId;
private volatile boolean timedOut;
public TransportClient(Channel channel, TransportResponseHandler handler) { public TransportClient(Channel channel, TransportResponseHandler handler) {
this.channel = Preconditions.checkNotNull(channel); this.channel = Preconditions.checkNotNull(channel);
this.handler = Preconditions.checkNotNull(handler); this.handler = Preconditions.checkNotNull(handler);
this.timedOut = false;
} }
public Channel getChannel() { public Channel getChannel() {
...@@ -84,7 +86,7 @@ public class TransportClient implements Closeable { ...@@ -84,7 +86,7 @@ public class TransportClient implements Closeable {
} }
public boolean isActive() { public boolean isActive() {
return channel.isOpen() || channel.isActive(); return !timedOut && (channel.isOpen() || channel.isActive());
} }
public SocketAddress getSocketAddress() { public SocketAddress getSocketAddress() {
...@@ -263,6 +265,11 @@ public class TransportClient implements Closeable { ...@@ -263,6 +265,11 @@ public class TransportClient implements Closeable {
} }
} }
/**
 * Mark this channel as having timed out.
 *
 * Sets the {@code timedOut} flag; the updated {@code isActive()} checks this flag first and
 * reports the client as inactive once it is set. This method only records the timeout and
 * does not close the channel itself.
 */
public void timeOut() {
this.timedOut = true;
}
@Override @Override
public void close() { public void close() {
// close is a local operation and should finish with milliseconds; timeout just to be safe // close is a local operation and should finish with milliseconds; timeout just to be safe
......
...@@ -136,8 +136,19 @@ public class TransportClientFactory implements Closeable { ...@@ -136,8 +136,19 @@ public class TransportClientFactory implements Closeable {
TransportClient cachedClient = clientPool.clients[clientIndex]; TransportClient cachedClient = clientPool.clients[clientIndex];
if (cachedClient != null && cachedClient.isActive()) { if (cachedClient != null && cachedClient.isActive()) {
logger.trace("Returning cached connection to {}: {}", address, cachedClient); // Make sure that the channel will not timeout by updating the last use time of the
return cachedClient; // handler. Then check that the client is still alive, in case it timed out before
// this code was able to update things.
TransportChannelHandler handler = cachedClient.getChannel().pipeline()
.get(TransportChannelHandler.class);
synchronized (handler) {
handler.getResponseHandler().updateTimeOfLastRequest();
}
if (cachedClient.isActive()) {
logger.trace("Returning cached connection to {}: {}", address, cachedClient);
return cachedClient;
}
} }
// If we reach here, we don't have an existing connection open. Let's create a new one. // If we reach here, we don't have an existing connection open. Let's create a new one.
......
...@@ -71,7 +71,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> { ...@@ -71,7 +71,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
} }
public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
timeOfLastRequestNs.set(System.nanoTime()); updateTimeOfLastRequest();
outstandingFetches.put(streamChunkId, callback); outstandingFetches.put(streamChunkId, callback);
} }
...@@ -80,7 +80,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> { ...@@ -80,7 +80,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
} }
public void addRpcRequest(long requestId, RpcResponseCallback callback) { public void addRpcRequest(long requestId, RpcResponseCallback callback) {
timeOfLastRequestNs.set(System.nanoTime()); updateTimeOfLastRequest();
outstandingRpcs.put(requestId, callback); outstandingRpcs.put(requestId, callback);
} }
...@@ -227,4 +227,9 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> { ...@@ -227,4 +227,9 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
return timeOfLastRequestNs.get(); return timeOfLastRequestNs.get();
} }
/**
 * Updates the time of the last request by storing the current {@link System#nanoTime()}
 * reading in {@code timeOfLastRequestNs}.
 *
 * The stored value is a {@code nanoTime} reading, not a wall-clock timestamp, so it is only
 * meaningful when compared against another {@code System.nanoTime()} value.
 */
public void updateTimeOfLastRequest() {
timeOfLastRequestNs.set(System.nanoTime());
}
} }
...@@ -116,20 +116,32 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message ...@@ -116,20 +116,32 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler<Message
// there are outstanding requests, we also do a secondary consistency check to ensure // there are outstanding requests, we also do a secondary consistency check to ensure
// there's no race between the idle timeout and incrementing the numOutstandingRequests // there's no race between the idle timeout and incrementing the numOutstandingRequests
// (see SPARK-7003). // (see SPARK-7003).
boolean isActuallyOverdue = //
System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; // To avoid a race between TransportClientFactory.createClient() and this code which could
if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { // result in an inactive client being returned, this needs to run in a synchronized block.
if (responseHandler.numOutstandingRequests() > 0) { synchronized (this) {
String address = NettyUtils.getRemoteAddress(ctx.channel()); boolean isActuallyOverdue =
logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs;
"requests. Assuming connection is dead; please adjust spark.network.timeout if this " + if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) {
"is wrong.", address, requestTimeoutNs / 1000 / 1000); if (responseHandler.numOutstandingRequests() > 0) {
ctx.close(); String address = NettyUtils.getRemoteAddress(ctx.channel());
} else if (closeIdleConnections) { logger.error("Connection to {} has been quiet for {} ms while there are outstanding " +
// While CloseIdleConnections is enable, we also close idle connection "requests. Assuming connection is dead; please adjust spark.network.timeout if this " +
ctx.close(); "is wrong.", address, requestTimeoutNs / 1000 / 1000);
client.timeOut();
ctx.close();
} else if (closeIdleConnections) {
// While CloseIdleConnections is enable, we also close idle connection
client.timeOut();
ctx.close();
}
} }
} }
} }
} }
/** Returns the {@code responseHandler} held by this channel handler. */
public TransportResponseHandler getResponseHandler() {
return responseHandler;
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment