From 799e13161e89f1ea96cb1bc7b507a05af2e89cd0 Mon Sep 17 00:00:00 2001 From: jinxing <jinxing6042@126.com> Date: Tue, 25 Jul 2017 20:52:07 +0800 Subject: [PATCH] [SPARK-21175] Reject OpenBlocks when memory shortage on shuffle service. ## What changes were proposed in this pull request? A shuffle service can serves blocks from multiple apps/tasks. Thus the shuffle service can suffers high memory usage when lots of shuffle-reads happen at the same time. In my cluster, OOM always happens on shuffle service. Analyzing heap dump, memory cost by Netty(ChannelOutboundBufferEntry) can be up to 2~3G. It might make sense to reject "open blocks" request when memory usage is high on shuffle service. https://github.com/apache/spark/commit/93dd0c518d040155b04e5ab258c5835aec7776fc and https://github.com/apache/spark/commit/85c6ce61930490e2247fb4b0e22dfebbb8b6a1ee tried to alleviate the memory pressure on shuffle service but cannot solve the root cause. This pr proposes to control currency of shuffle read. ## How was this patch tested? Added unit test. Author: jinxing <jinxing6042@126.com> Closes #18388 from jinxing64/SPARK-21175. --- .../spark/network/TransportContext.java | 2 +- .../server/OneForOneStreamManager.java | 60 +++++++- .../spark/network/server/StreamManager.java | 27 ++++ .../server/TransportRequestHandler.java | 42 +++++- .../spark/network/util/TransportConf.java | 6 + .../network/TransportRequestHandlerSuite.java | 134 ++++++++++++++++++ docs/configuration.md | 7 + 7 files changed, 265 insertions(+), 13 deletions(-) create mode 100644 common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index 965c4ae307..ae91bc9cfd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -168,7 +168,7 @@ public class TransportContext { TransportResponseHandler responseHandler = new TransportResponseHandler(channel); TransportClient client = new TransportClient(channel, responseHandler); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, - rpcHandler); + rpcHandler, conf.maxChunksBeingTransferred()); return new TransportChannelHandler(client, responseHandler, requestHandler, conf.connectionTimeoutMs(), closeIdleConnections); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index 85ca2f1728..0f6a8824d9 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -25,6 +25,8 @@ import java.util.concurrent.atomic.AtomicLong; import com.google.common.base.Preconditions; import io.netty.channel.Channel; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -53,6 +55,9 @@ public class OneForOneStreamManager extends StreamManager { // that the caller only requests each chunk one at a time, in order. int curChunk = 0; + // Used to keep track of the number of chunks being transferred and not finished yet. + volatile long chunksBeingTransferred = 0L; + StreamState(String appId, Iterator<ManagedBuffer> buffers) { this.appId = appId; this.buffers = Preconditions.checkNotNull(buffers); @@ -96,18 +101,25 @@ public class OneForOneStreamManager extends StreamManager { @Override public ManagedBuffer openStream(String streamChunkId) { - String[] array = streamChunkId.split("_"); - assert array.length == 2: - "Stream id and chunk index should be specified when open stream for fetching block."; - long streamId = Long.valueOf(array[0]); - int chunkIndex = Integer.valueOf(array[1]); - return getChunk(streamId, chunkIndex); + Pair<Long, Integer> streamChunkIdPair = parseStreamChunkId(streamChunkId); + return getChunk(streamChunkIdPair.getLeft(), streamChunkIdPair.getRight()); } public static String genStreamChunkId(long streamId, int chunkId) { return String.format("%d_%d", streamId, chunkId); } + // Parse streamChunkId to be stream id and chunk id. This is used when fetch remote chunk as a + // stream. + public static Pair<Long, Integer> parseStreamChunkId(String streamChunkId) { + String[] array = streamChunkId.split("_"); + assert array.length == 2: + "Stream id and chunk index should be specified."; + long streamId = Long.valueOf(array[0]); + int chunkIndex = Integer.valueOf(array[1]); + return ImmutablePair.of(streamId, chunkIndex); + } + @Override public void connectionTerminated(Channel channel) { // Close all streams which have been associated with the channel. @@ -139,6 +151,42 @@ public class OneForOneStreamManager extends StreamManager { } } + @Override + public void chunkBeingSent(long streamId) { + StreamState streamState = streams.get(streamId); + if (streamState != null) { + streamState.chunksBeingTransferred++; + } + + } + + @Override + public void streamBeingSent(String streamId) { + chunkBeingSent(parseStreamChunkId(streamId).getLeft()); + } + + @Override + public void chunkSent(long streamId) { + StreamState streamState = streams.get(streamId); + if (streamState != null) { + streamState.chunksBeingTransferred--; + } + } + + @Override + public void streamSent(String streamId) { + chunkSent(OneForOneStreamManager.parseStreamChunkId(streamId).getLeft()); + } + + @Override + public long chunksBeingTransferred() { + long sum = 0L; + for (StreamState streamState: streams.values()) { + sum += streamState.chunksBeingTransferred; + } + return sum; + } + /** * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java index 07f161a29c..c535295831 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -83,4 +83,31 @@ public abstract class StreamManager { */ public void checkAuthorization(TransportClient client, long streamId) { } + /** + * Return the number of chunks being transferred and not finished yet in this StreamManager. + */ + public long chunksBeingTransferred() { + return 0; + } + + /** + * Called when start sending a chunk. + */ + public void chunkBeingSent(long streamId) { } + + /** + * Called when start sending a stream. + */ + public void streamBeingSent(String streamId) { } + + /** + * Called when a chunk is successfully sent. + */ + public void chunkSent(long streamId) { } + + /** + * Called when a stream is successfully sent. + */ + public void streamSent(String streamId) { } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 8193bc1376..e94453578e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -22,6 +22,7 @@ import java.nio.ByteBuffer; import com.google.common.base.Throwables; import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -65,14 +66,19 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { /** Returns each chunk part of a stream. */ private final StreamManager streamManager; + /** The max number of chunks being transferred and not finished yet. */ + private final long maxChunksBeingTransferred; + public TransportRequestHandler( Channel channel, TransportClient reverseClient, - RpcHandler rpcHandler) { + RpcHandler rpcHandler, + Long maxChunksBeingTransferred) { this.channel = channel; this.reverseClient = reverseClient; this.rpcHandler = rpcHandler; this.streamManager = rpcHandler.getStreamManager(); + this.maxChunksBeingTransferred = maxChunksBeingTransferred; } @Override @@ -117,7 +123,13 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel), req.streamChunkId); } - + long chunksBeingTransferred = streamManager.chunksBeingTransferred(); + if (chunksBeingTransferred >= maxChunksBeingTransferred) { + logger.warn("The number of chunks being transferred {} is above {}, close the connection.", + chunksBeingTransferred, maxChunksBeingTransferred); + channel.close(); + return; + } ManagedBuffer buf; try { streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId); @@ -130,10 +142,25 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { return; } - respond(new ChunkFetchSuccess(req.streamChunkId, buf)); + streamManager.chunkBeingSent(req.streamChunkId.streamId); + respond(new ChunkFetchSuccess(req.streamChunkId, buf)).addListener(future -> { + streamManager.chunkSent(req.streamChunkId.streamId); + }); } private void processStreamRequest(final StreamRequest req) { + if (logger.isTraceEnabled()) { + logger.trace("Received req from {} to fetch stream {}", getRemoteAddress(channel), + req.streamId); + } + + long chunksBeingTransferred = streamManager.chunksBeingTransferred(); + if (chunksBeingTransferred >= maxChunksBeingTransferred) { + logger.warn("The number of chunks being transferred {} is above {}, close the connection.", + chunksBeingTransferred, maxChunksBeingTransferred); + channel.close(); + return; + } ManagedBuffer buf; try { buf = streamManager.openStream(req.streamId); @@ -145,7 +172,10 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { } if (buf != null) { - respond(new StreamResponse(req.streamId, buf.size(), buf)); + streamManager.streamBeingSent(req.streamId); + respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> { + streamManager.streamSent(req.streamId); + }); } else { respond(new StreamFailure(req.streamId, String.format( "Stream '%s' was not found.", req.streamId))); @@ -187,9 +217,9 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> { * Responds to a single message with some Encodable object. If a failure occurs while sending, * it will be logged and the channel closed. */ - private void respond(Encodable result) { + private ChannelFuture respond(Encodable result) { SocketAddress remoteAddress = channel.remoteAddress(); - channel.writeAndFlush(result).addListener(future -> { + return channel.writeAndFlush(result).addListener(future -> { if (future.isSuccess()) { logger.trace("Sent result {} to client {}", result, remoteAddress); } else { diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index a25078e262..ea52e9fe6c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -257,4 +257,10 @@ public class TransportConf { return CryptoUtils.toCryptoConf("spark.network.crypto.config.", conf.getAll()); } + /** + * The max number of chunks allowed to being transferred at the same time on shuffle service. + */ + public long maxChunksBeingTransferred() { + return conf.getLong("spark.shuffle.maxChunksBeingTransferred", Long.MAX_VALUE); + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java new file mode 100644 index 0000000000..1fb987a8a7 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.util.ArrayList; +import java.util.List; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import org.junit.Test; + +import static org.mockito.Mockito.*; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.*; +import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportRequestHandler; + +public class TransportRequestHandlerSuite { + + @Test + public void handleFetchRequestAndStreamRequest() throws Exception { + RpcHandler rpcHandler = new NoOpRpcHandler(); + OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager()); + Channel channel = mock(Channel.class); + List<Pair<Object, ExtendedChannelPromise>> responseAndPromisePairs = + new ArrayList<>(); + when(channel.writeAndFlush(any())) + .thenAnswer(invocationOnMock0 -> { + Object response = invocationOnMock0.getArguments()[0]; + ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel); + responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture)); + return channelFuture; + }); + + // Prepare the stream. + List<ManagedBuffer> managedBuffers = new ArrayList<>(); + managedBuffers.add(new TestManagedBuffer(10)); + managedBuffers.add(new TestManagedBuffer(20)); + managedBuffers.add(new TestManagedBuffer(30)); + managedBuffers.add(new TestManagedBuffer(40)); + long streamId = streamManager.registerStream("test-app", managedBuffers.iterator()); + streamManager.registerChannel(channel, streamId); + TransportClient reverseClient = mock(TransportClient.class); + TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, + rpcHandler, 2L); + + RequestMessage request0 = new ChunkFetchRequest(new StreamChunkId(streamId, 0)); + requestHandler.handle(request0); + assert responseAndPromisePairs.size() == 1; + assert responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess; + assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body() == + managedBuffers.get(0); + + RequestMessage request1 = new ChunkFetchRequest(new StreamChunkId(streamId, 1)); + requestHandler.handle(request1); + assert responseAndPromisePairs.size() == 2; + assert responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess; + assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(1).getLeft())).body() == + managedBuffers.get(1); + + // Finish flushing the response for request0. + responseAndPromisePairs.get(0).getRight().finish(true); + + RequestMessage request2 = new StreamRequest(String.format("%d_%d", streamId, 2)); + requestHandler.handle(request2); + assert responseAndPromisePairs.size() == 3; + assert responseAndPromisePairs.get(2).getLeft() instanceof StreamResponse; + assert ((StreamResponse) (responseAndPromisePairs.get(2).getLeft())).body() == + managedBuffers.get(2); + + // Request3 will trigger the close of channel, because the number of max chunks being + // transferred is 2; + RequestMessage request3 = new StreamRequest(String.format("%d_%d", streamId, 3)); + requestHandler.handle(request3); + verify(channel, times(1)).close(); + assert responseAndPromisePairs.size() == 3; + } + + private class ExtendedChannelPromise extends DefaultChannelPromise { + + private List<GenericFutureListener> listeners = new ArrayList<>(); + private boolean success; + + public ExtendedChannelPromise(Channel channel) { + super(channel); + success = false; + } + + @Override + public ChannelPromise addListener( + GenericFutureListener<? extends Future<? super Void>> listener) { + listeners.add(listener); + return super.addListener(listener); + } + + @Override + public boolean isSuccess() { + return success; + } + + public void finish(boolean success) { + this.success = success; + listeners.forEach(listener -> { + try { + listener.operationComplete(this); + } catch (Exception e) { } + }); + } + } +} diff --git a/docs/configuration.md b/docs/configuration.md index d3df923c42..f4b6f46db5 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -631,6 +631,13 @@ Apart from these, the following properties are also available, and may be useful Max number of entries to keep in the index cache of the shuffle service. </td> </tr> +<tr> + <td><code>spark.shuffle.maxChunksBeingTransferred</code></td> + <td>Long.MAX_VALUE</td> + <td> + The max number of chunks allowed to being transferred at the same time on shuffle service. + </td> +</tr> <tr> <td><code>spark.shuffle.sort.bypassMergeThreshold</code></td> <td>200</td> -- GitLab