diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index 965c4ae3076678aadf164c306751fd61596f6805..ae91bc9cfdd083cb9c2d3060cd66836fbc32573c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -168,7 +168,7 @@ public class TransportContext { TransportResponseHandler responseHandler = new TransportResponseHandler(channel); TransportClient client = new TransportClient(channel, responseHandler); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, - rpcHandler); + rpcHandler, conf.maxChunksBeingTransferred()); return new TransportChannelHandler(client, responseHandler, requestHandler, conf.connectionTimeoutMs(), closeIdleConnections); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index 85ca2f1728e6a844a3c8bada187442a2be51261d..0f6a8824d95e50947e1304633794577b4f6a77ad 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -25,6 +25,8 @@ import java.util.concurrent.atomic.AtomicLong; import com.google.common.base.Preconditions; import io.netty.channel.Channel; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -53,6 +55,9 @@ public class OneForOneStreamManager extends StreamManager { // that the caller only requests each chunk one at a time, in order. int curChunk = 0; + // Used to keep track of the number of chunks being transferred and not finished yet. 
+ volatile long chunksBeingTransferred = 0L; + StreamState(String appId, Iterator<ManagedBuffer> buffers) { this.appId = appId; this.buffers = Preconditions.checkNotNull(buffers); } @@ -96,18 +101,25 @@ public class OneForOneStreamManager extends StreamManager { @Override public ManagedBuffer openStream(String streamChunkId) { - String[] array = streamChunkId.split("_"); - assert array.length == 2: - "Stream id and chunk index should be specified when open stream for fetching block."; - long streamId = Long.valueOf(array[0]); - int chunkIndex = Integer.valueOf(array[1]); - return getChunk(streamId, chunkIndex); + Pair<Long, Integer> streamChunkIdPair = parseStreamChunkId(streamChunkId); + return getChunk(streamChunkIdPair.getLeft(), streamChunkIdPair.getRight()); } public static String genStreamChunkId(long streamId, int chunkId) { return String.format("%d_%d", streamId, chunkId); } + // Parse streamChunkId to be stream id and chunk id. This is used when fetching a remote chunk as a + // stream. + public static Pair<Long, Integer> parseStreamChunkId(String streamChunkId) { + String[] array = streamChunkId.split("_"); + assert array.length == 2: + "Stream id and chunk index should be specified."; + long streamId = Long.valueOf(array[0]); + int chunkIndex = Integer.valueOf(array[1]); + return ImmutablePair.of(streamId, chunkIndex); + } + @Override public void connectionTerminated(Channel channel) { // Close all streams which have been associated with the channel. 
@@ -139,6 +151,42 @@ public class OneForOneStreamManager extends StreamManager { } } + @Override + public void chunkBeingSent(long streamId) { + StreamState streamState = streams.get(streamId); + if (streamState != null) { + streamState.chunksBeingTransferred++; + } + + } + + @Override + public void streamBeingSent(String streamId) { + chunkBeingSent(parseStreamChunkId(streamId).getLeft()); + } + + @Override + public void chunkSent(long streamId) { + StreamState streamState = streams.get(streamId); + if (streamState != null) { + streamState.chunksBeingTransferred--; + } + } + + @Override + public void streamSent(String streamId) { + chunkSent(OneForOneStreamManager.parseStreamChunkId(streamId).getLeft()); + } + + @Override + public long chunksBeingTransferred() { + long sum = 0L; + for (StreamState streamState: streams.values()) { + sum += streamState.chunksBeingTransferred; + } + return sum; + } + /** * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java index 07f161a29cfb8c8968423114a94e048861b86b6c..c535295831606fd8db8b072910be63473d8decea 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -83,4 +83,31 @@ public abstract class StreamManager { */ public void checkAuthorization(TransportClient client, long streamId) { } + /** + * Return the number of chunks being transferred and not finished yet in this StreamManager. */ + public long chunksBeingTransferred() { + return 0; + } + + /** + * Called when starting to send a chunk. 
+ */ + public void chunkBeingSent(long streamId) { } + + /** + * Called when starting to send a stream. + */ + public void streamBeingSent(String streamId) { } + + /** + * Called when a chunk is successfully sent. + */ + public void chunkSent(long streamId) { } + + /** + * Called when a stream is successfully sent. + */ + public void streamSent(String streamId) { } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 8193bc137610255592920c7419eae9f4ed8db391..e94453578e6b080db0ebbb4297b24a9259647469 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -22,6 +22,7 @@ import java.nio.ByteBuffer; import com.google.common.base.Throwables; import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -65,14 +66,19 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { /** Returns each chunk part of a stream. */ private final StreamManager streamManager; + /** The max number of chunks being transferred and not finished yet. 
*/ + private final long maxChunksBeingTransferred; + public TransportRequestHandler( Channel channel, TransportClient reverseClient, - RpcHandler rpcHandler) { + RpcHandler rpcHandler, + Long maxChunksBeingTransferred) { this.channel = channel; this.reverseClient = reverseClient; this.rpcHandler = rpcHandler; this.streamManager = rpcHandler.getStreamManager(); + this.maxChunksBeingTransferred = maxChunksBeingTransferred; } @Override @@ -117,7 +123,13 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel), req.streamChunkId); } - + long chunksBeingTransferred = streamManager.chunksBeingTransferred(); + if (chunksBeingTransferred >= maxChunksBeingTransferred) { + logger.warn("The number of chunks being transferred {} is above {}, close the connection.", + chunksBeingTransferred, maxChunksBeingTransferred); + channel.close(); + return; + } ManagedBuffer buf; try { streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId); @@ -130,10 +142,25 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { return; } - respond(new ChunkFetchSuccess(req.streamChunkId, buf)); + streamManager.chunkBeingSent(req.streamChunkId.streamId); + respond(new ChunkFetchSuccess(req.streamChunkId, buf)).addListener(future -> { + streamManager.chunkSent(req.streamChunkId.streamId); + }); } private void processStreamRequest(final StreamRequest req) { + if (logger.isTraceEnabled()) { + logger.trace("Received req from {} to fetch stream {}", getRemoteAddress(channel), + req.streamId); + } + + long chunksBeingTransferred = streamManager.chunksBeingTransferred(); + if (chunksBeingTransferred >= maxChunksBeingTransferred) { + logger.warn("The number of chunks being transferred {} is above {}, close the connection.", + chunksBeingTransferred, maxChunksBeingTransferred); + channel.close(); + return; + } ManagedBuffer buf; try { buf = 
streamManager.openStream(req.streamId); @@ -145,7 +172,10 @@ } if (buf != null) { - respond(new StreamResponse(req.streamId, buf.size(), buf)); + streamManager.streamBeingSent(req.streamId); + respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> { + streamManager.streamSent(req.streamId); + }); } else { respond(new StreamFailure(req.streamId, String.format( "Stream '%s' was not found.", req.streamId))); @@ -187,9 +217,9 @@ * Responds to a single message with some Encodable object. If a failure occurs while sending, * it will be logged and the channel closed. */ - private void respond(Encodable result) { + private ChannelFuture respond(Encodable result) { SocketAddress remoteAddress = channel.remoteAddress(); - channel.writeAndFlush(result).addListener(future -> { + return channel.writeAndFlush(result).addListener(future -> { if (future.isSuccess()) { logger.trace("Sent result {} to client {}", result, remoteAddress); } else { diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index a25078e262efb9ed8cadc6b1caaa72072d6ce6c4..ea52e9fe6c1c188e5f84620be0cb440e12ae5e1a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -257,4 +257,10 @@ public class TransportConf { return CryptoUtils.toCryptoConf("spark.network.crypto.config.", conf.getAll()); } + /** + * The max number of chunks allowed to be transferred at the same time on shuffle service. 
+ */ + public long maxChunksBeingTransferred() { + return conf.getLong("spark.shuffle.maxChunksBeingTransferred", Long.MAX_VALUE); + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..1fb987a8a7aa76dfef950a4bb89330ea0ca53d2e --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import java.util.ArrayList; +import java.util.List; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import org.junit.Test; + +import static org.mockito.Mockito.*; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.*; +import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportRequestHandler; + +public class TransportRequestHandlerSuite { + + @Test + public void handleFetchRequestAndStreamRequest() throws Exception { + RpcHandler rpcHandler = new NoOpRpcHandler(); + OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager()); + Channel channel = mock(Channel.class); + List<Pair<Object, ExtendedChannelPromise>> responseAndPromisePairs = + new ArrayList<>(); + when(channel.writeAndFlush(any())) + .thenAnswer(invocationOnMock0 -> { + Object response = invocationOnMock0.getArguments()[0]; + ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel); + responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture)); + return channelFuture; + }); + + // Prepare the stream. 
+ List<ManagedBuffer> managedBuffers = new ArrayList<>(); + managedBuffers.add(new TestManagedBuffer(10)); + managedBuffers.add(new TestManagedBuffer(20)); + managedBuffers.add(new TestManagedBuffer(30)); + managedBuffers.add(new TestManagedBuffer(40)); + long streamId = streamManager.registerStream("test-app", managedBuffers.iterator()); + streamManager.registerChannel(channel, streamId); + TransportClient reverseClient = mock(TransportClient.class); + TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, + rpcHandler, 2L); + + RequestMessage request0 = new ChunkFetchRequest(new StreamChunkId(streamId, 0)); + requestHandler.handle(request0); + assert responseAndPromisePairs.size() == 1; + assert responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess; + assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body() == + managedBuffers.get(0); + + RequestMessage request1 = new ChunkFetchRequest(new StreamChunkId(streamId, 1)); + requestHandler.handle(request1); + assert responseAndPromisePairs.size() == 2; + assert responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess; + assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(1).getLeft())).body() == + managedBuffers.get(1); + + // Finish flushing the response for request0. 
+ responseAndPromisePairs.get(0).getRight().finish(true); + + RequestMessage request2 = new StreamRequest(String.format("%d_%d", streamId, 2)); + requestHandler.handle(request2); + assert responseAndPromisePairs.size() == 3; + assert responseAndPromisePairs.get(2).getLeft() instanceof StreamResponse; + assert ((StreamResponse) (responseAndPromisePairs.get(2).getLeft())).body() == + managedBuffers.get(2); + + // Request3 will trigger the close of channel, because the number of max chunks being + // transferred is 2; + RequestMessage request3 = new StreamRequest(String.format("%d_%d", streamId, 3)); + requestHandler.handle(request3); + verify(channel, times(1)).close(); + assert responseAndPromisePairs.size() == 3; + } + + private class ExtendedChannelPromise extends DefaultChannelPromise { + + private List<GenericFutureListener> listeners = new ArrayList<>(); + private boolean success; + + public ExtendedChannelPromise(Channel channel) { + super(channel); + success = false; + } + + @Override + public ChannelPromise addListener( + GenericFutureListener<? extends Future<? super Void>> listener) { + listeners.add(listener); + return super.addListener(listener); + } + + @Override + public boolean isSuccess() { + return success; + } + + public void finish(boolean success) { + this.success = success; + listeners.forEach(listener -> { + try { + listener.operationComplete(this); + } catch (Exception e) { } + }); + } + } +} diff --git a/docs/configuration.md b/docs/configuration.md index d3df923c42690177c89058baaab0fc8e42d114d0..f4b6f46db5b66a743cb9d005eb6d59c404b452a3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -631,6 +631,13 @@ Apart from these, the following properties are also available, and may be useful Max number of entries to keep in the index cache of the shuffle service. 
</td> </tr> +<tr> + <td><code>spark.shuffle.maxChunksBeingTransferred</code></td> + <td>Long.MAX_VALUE</td> + <td> + The max number of chunks allowed to be transferred at the same time on shuffle service. + </td> +</tr> <tr> <td><code>spark.shuffle.sort.bypassMergeThreshold</code></td> <td>200</td>