diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index fcefe64d59c91bf8a896b3f0ad096fc977e6beb3..ca99fa89ebe1b72a028e4efeb1d08f84149d67b7 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -76,6 +76,10 @@ <artifactId>guava</artifactId> <scope>compile</scope> </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-crypto</artifactId> + </dependency> <!-- Test dependencies --> <dependency> diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 9e5c616ee5a1fd4fbf2e3280bdc16b13867af42e..a1bb453657460c9417ebfdacada581e06c15d39f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -30,6 +30,8 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.sasl.aes.AesCipher; +import org.apache.spark.network.sasl.aes.AesConfigMessage; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; @@ -88,9 +90,26 @@ public class SaslClientBootstrap implements TransportClientBootstrap { throw new RuntimeException( new SaslException("Encryption requests by negotiated non-encrypted connection.")); } - SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); + + if (conf.aesEncryptionEnabled()) { + // Generate a request config message to send to server. + AesConfigMessage configMessage = AesCipher.createConfigMessage(conf); + ByteBuffer buf = configMessage.encodeMessage(); + + // Encrypted the config message. + byte[] toEncrypt = JavaUtils.bufferToArray(buf); + ByteBuffer encrypted = ByteBuffer.wrap(saslClient.wrap(toEncrypt, 0, toEncrypt.length)); + + client.sendRpcSync(encrypted, conf.saslRTTimeoutMs()); + AesCipher cipher = new AesCipher(configMessage, conf); + logger.info("Enabling AES cipher for client channel {}", client); + cipher.addToChannel(channel); + saslClient.dispose(); + } else { + SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); + } saslClient = null; - logger.debug("Channel {} configured for SASL encryption.", client); + logger.debug("Channel {} configured for encryption.", client); } } catch (IOException ioe) { throw new RuntimeException(ioe); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index c41f5b6873f6cc9e91f07e271a907c476f3bc65b..b2f3ef214b7acafbd7d703cd95e3c19b9c6bb0b2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -29,6 +29,8 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.aes.AesCipher; +import org.apache.spark.network.sasl.aes.AesConfigMessage; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.JavaUtils; @@ -59,6 +61,7 @@ class SaslRpcHandler extends RpcHandler { private SparkSaslServer saslServer; private boolean isComplete; + private boolean isAuthenticated; SaslRpcHandler( TransportConf conf, @@ -71,6 +74,7 @@ class SaslRpcHandler extends RpcHandler { this.secretKeyHolder = secretKeyHolder; this.saslServer = null; this.isComplete = false; + this.isAuthenticated = false; } @Override @@ -80,30 +84,31 @@ class SaslRpcHandler extends RpcHandler { delegate.receive(client, message, callback); return; } + if (saslServer == null || !saslServer.isComplete()) { + ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + SaslMessage saslMessage; + try { + saslMessage = SaslMessage.decode(nettyBuf); + } finally { + nettyBuf.release(); + } - ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); - SaslMessage saslMessage; - try { - saslMessage = SaslMessage.decode(nettyBuf); - } finally { - nettyBuf.release(); - } - - if (saslServer == null) { - // First message in the handshake, setup the necessary state. - client.setClientId(saslMessage.appId); - saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, - conf.saslServerAlwaysEncrypt()); - } + if (saslServer == null) { + // First message in the handshake, setup the necessary state. + client.setClientId(saslMessage.appId); + saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, + conf.saslServerAlwaysEncrypt()); + } - byte[] response; - try { - response = saslServer.response(JavaUtils.bufferToArray( - saslMessage.body().nioByteBuffer())); - } catch (IOException ioe) { - throw new RuntimeException(ioe); + byte[] response; + try { + response = saslServer.response(JavaUtils.bufferToArray( + saslMessage.body().nioByteBuffer())); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + callback.onSuccess(ByteBuffer.wrap(response)); } - callback.onSuccess(ByteBuffer.wrap(response)); // Setup encryption after the SASL response is sent, otherwise the client can't parse the // response. It's ok to change the channel pipeline here since we are processing an incoming @@ -111,15 +116,42 @@ class SaslRpcHandler extends RpcHandler { // method returns. This assumes that the code ensures, through other means, that no outbound // messages are being written to the channel while negotiation is still going on. if (saslServer.isComplete()) { - logger.debug("SASL authentication successful for channel {}", client); - isComplete = true; - if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { + if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { + logger.debug("SASL authentication successful for channel {}", client); + complete(true); + return; + } + + if (!conf.aesEncryptionEnabled()) { logger.debug("Enabling encryption for channel {}", client); SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); - saslServer = null; - } else { - saslServer.dispose(); - saslServer = null; + complete(false); + return; + } + + // Extra negotiation should happen after authentication, so return directly while + // processing authenticate. + if (!isAuthenticated) { + logger.debug("SASL authentication successful for channel {}", client); + isAuthenticated = true; + return; + } + + // Create AES cipher when it is authenticated + try { + byte[] encrypted = JavaUtils.bufferToArray(message); + ByteBuffer decrypted = ByteBuffer.wrap(saslServer.unwrap(encrypted, 0 , encrypted.length)); + + AesConfigMessage configMessage = AesConfigMessage.decodeMessage(decrypted); + AesCipher cipher = new AesCipher(configMessage, conf); + + // Send response back to client to confirm that server accept config. + callback.onSuccess(JavaUtils.stringToBytes(AesCipher.TRANSFORM)); + logger.info("Enabling AES cipher for Server channel {}", client); + cipher.addToChannel(channel); + complete(true); + } catch (IOException ioe) { + throw new RuntimeException(ioe); } } } @@ -155,4 +187,17 @@ class SaslRpcHandler extends RpcHandler { delegate.exceptionCaught(cause, client); } + private void complete(boolean dispose) { + if (dispose) { + try { + saslServer.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL server", e); + } + } + + saslServer = null; + isComplete = true; + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java new file mode 100644 index 0000000000000000000000000000000000000000..78034a69f734d403704da203cc0d5f9f2a29965f --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl.aes; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Properties; +import javax.crypto.spec.SecretKeySpec; +import javax.crypto.spec.IvParameterSpec; + +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.*; +import io.netty.util.AbstractReferenceCounted; +import org.apache.commons.crypto.cipher.CryptoCipherFactory; +import org.apache.commons.crypto.random.CryptoRandom; +import org.apache.commons.crypto.random.CryptoRandomFactory; +import org.apache.commons.crypto.stream.CryptoInputStream; +import org.apache.commons.crypto.stream.CryptoOutputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.util.ByteArrayReadableChannel; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.TransportConf; + +/** + * AES cipher for encryption and decryption. + */ +public class AesCipher { + private static final Logger logger = LoggerFactory.getLogger(AesCipher.class); + public static final String ENCRYPTION_HANDLER_NAME = "AesEncryption"; + public static final String DECRYPTION_HANDLER_NAME = "AesDecryption"; + public static final int STREAM_BUFFER_SIZE = 1024 * 32; + public static final String TRANSFORM = "AES/CTR/NoPadding"; + + private final SecretKeySpec inKeySpec; + private final IvParameterSpec inIvSpec; + private final SecretKeySpec outKeySpec; + private final IvParameterSpec outIvSpec; + private final Properties properties; + + public AesCipher(AesConfigMessage configMessage, TransportConf conf) throws IOException { + this.properties = CryptoStreamUtils.toCryptoConf(conf); + this.inKeySpec = new SecretKeySpec(configMessage.inKey, "AES"); + this.inIvSpec = new IvParameterSpec(configMessage.inIv); + this.outKeySpec = new SecretKeySpec(configMessage.outKey, "AES"); + this.outIvSpec = new IvParameterSpec(configMessage.outIv); + } + + /** + * Create AES crypto output stream + * @param ch The underlying channel to write out. + * @return Return output crypto stream for encryption. + * @throws IOException + */ + private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { + return new CryptoOutputStream(TRANSFORM, properties, ch, outKeySpec, outIvSpec); + } + + /** + * Create AES crypto input stream + * @param ch The underlying channel used to read data. + * @return Return input crypto stream for decryption. + * @throws IOException + */ + private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { + return new CryptoInputStream(TRANSFORM, properties, ch, inKeySpec, inIvSpec); + } + + /** + * Add handlers to channel + * @param ch the channel for adding handlers + * @throws IOException + */ + public void addToChannel(Channel ch) throws IOException { + ch.pipeline() + .addFirst(ENCRYPTION_HANDLER_NAME, new AesEncryptHandler(this)) + .addFirst(DECRYPTION_HANDLER_NAME, new AesDecryptHandler(this)); + } + + /** + * Create the configuration message + * @param conf is the local transport configuration. + * @return Config message for sending. + */ + public static AesConfigMessage createConfigMessage(TransportConf conf) { + int keySize = conf.aesCipherKeySize(); + Properties properties = CryptoStreamUtils.toCryptoConf(conf); + + try { + int paramLen = CryptoCipherFactory.getCryptoCipher(AesCipher.TRANSFORM, properties) + .getBlockSize(); + byte[] inKey = new byte[keySize]; + byte[] outKey = new byte[keySize]; + byte[] inIv = new byte[paramLen]; + byte[] outIv = new byte[paramLen]; + + CryptoRandom random = CryptoRandomFactory.getCryptoRandom(properties); + random.nextBytes(inKey); + random.nextBytes(outKey); + random.nextBytes(inIv); + random.nextBytes(outIv); + + return new AesConfigMessage(inKey, inIv, outKey, outIv); + } catch (Exception e) { + logger.error("AES config error", e); + throw Throwables.propagate(e); + } + } + + /** + * CryptoStreamUtils is used to convert config from TransportConf to AES Crypto config. + */ + private static class CryptoStreamUtils { + public static Properties toCryptoConf(TransportConf conf) { + Properties props = new Properties(); + if (conf.aesCipherClass() != null) { + props.setProperty(CryptoCipherFactory.CLASSES_KEY, conf.aesCipherClass()); + } + return props; + } + } + + private static class AesEncryptHandler extends ChannelOutboundHandlerAdapter { + private final ByteArrayWritableChannel byteChannel; + private final CryptoOutputStream cos; + + AesEncryptHandler(AesCipher cipher) throws IOException { + byteChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE); + cos = cipher.createOutputStream(byteChannel); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + ctx.write(new EncryptedMessage(cos, msg, byteChannel), promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + try { + cos.close(); + } finally { + super.close(ctx, promise); + } + } + } + + private static class AesDecryptHandler extends ChannelInboundHandlerAdapter { + private final CryptoInputStream cis; + private final ByteArrayReadableChannel byteChannel; + + AesDecryptHandler(AesCipher cipher) throws IOException { + byteChannel = new ByteArrayReadableChannel(); + cis = cipher.createInputStream(byteChannel); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + byteChannel.feedData((ByteBuf) data); + + byte[] decryptedData = new byte[byteChannel.readableBytes()]; + int offset = 0; + while (offset < decryptedData.length) { + offset += cis.read(decryptedData, offset, decryptedData.length - offset); + } + + ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + try { + cis.close(); + } finally { + super.channelInactive(ctx); + } + } + } + + private static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion { + private final boolean isByteBuf; + private final ByteBuf buf; + private final FileRegion region; + private long transferred; + private CryptoOutputStream cos; + + // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has + // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data + // from upper handler, another is used to store encrypted data. + private ByteArrayWritableChannel byteEncChannel; + private ByteArrayWritableChannel byteRawChannel; + + private ByteBuffer currentEncrypted; + + EncryptedMessage(CryptoOutputStream cos, Object msg, ByteArrayWritableChannel ch) { + Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, + "Unrecognized message type: %s", msg.getClass().getName()); + this.isByteBuf = msg instanceof ByteBuf; + this.buf = isByteBuf ? (ByteBuf) msg : null; + this.region = isByteBuf ? null : (FileRegion) msg; + this.transferred = 0; + this.byteRawChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE); + this.cos = cos; + this.byteEncChannel = ch; + } + + @Override + public long count() { + return isByteBuf ? buf.readableBytes() : region.count(); + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return transferred; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + Preconditions.checkArgument(position == transfered(), "Invalid position."); + + do { + if (currentEncrypted == null) { + encryptMore(); + } + + int bytesWritten = currentEncrypted.remaining(); + target.write(currentEncrypted); + bytesWritten -= currentEncrypted.remaining(); + transferred += bytesWritten; + if (!currentEncrypted.hasRemaining()) { + currentEncrypted = null; + byteEncChannel.reset(); + } + } while (transferred < count()); + + return transferred; + } + + private void encryptMore() throws IOException { + byteRawChannel.reset(); + + if (isByteBuf) { + int copied = byteRawChannel.write(buf.nioBuffer()); + buf.skipBytes(copied); + } else { + region.transferTo(byteRawChannel, region.transfered()); + } + cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); + cos.flush(); + + currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), + 0, byteEncChannel.length()); + } + + @Override + protected void deallocate() { + byteRawChannel.reset(); + byteEncChannel.reset(); + if (region != null) { + region.release(); + } + if (buf != null) { + buf.release(); + } + } + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java new file mode 100644 index 0000000000000000000000000000000000000000..3ef6f74a1f89f8701295c6e529e0a90d4950fb2f --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl.aes; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** + * The AES cipher options for encryption negotiation. + */ +public class AesConfigMessage implements Encodable { + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xEB; + + public byte[] inKey; + public byte[] outKey; + public byte[] inIv; + public byte[] outIv; + + public AesConfigMessage(byte[] inKey, byte[] inIv, byte[] outKey, byte[] outIv) { + if (inKey == null || inIv == null || outKey == null || outIv == null) { + throw new IllegalArgumentException("Cipher Key or IV must not be null!"); + } + + this.inKey = inKey; + this.inIv = inIv; + this.outKey = outKey; + this.outIv = outIv; + } + + @Override + public int encodedLength() { + return 1 + + Encoders.ByteArrays.encodedLength(inKey) + Encoders.ByteArrays.encodedLength(outKey) + + Encoders.ByteArrays.encodedLength(inIv) + Encoders.ByteArrays.encodedLength(outIv); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.ByteArrays.encode(buf, inKey); + Encoders.ByteArrays.encode(buf, inIv); + Encoders.ByteArrays.encode(buf, outKey); + Encoders.ByteArrays.encode(buf, outIv); + } + + /** + * Encode the config message. + * @return ByteBuffer which contains encoded config message. + */ + public ByteBuffer encodeMessage(){ + ByteBuffer buf = ByteBuffer.allocate(encodedLength()); + + ByteBuf wrappedBuf = Unpooled.wrappedBuffer(buf); + wrappedBuf.clear(); + encode(wrappedBuf); + + return buf; + } + + /** + * Decode the config message from buffer + * @param buffer the buffer contain encoded config message + * @return config message + */ + public static AesConfigMessage decodeMessage(ByteBuffer buffer) { + ByteBuf buf = Unpooled.wrappedBuffer(buffer); + + if (buf.readByte() != TAG_BYTE) { + throw new IllegalStateException("Expected AesConfigMessage, received something else" + + " (maybe your client does not have AES enabled?)"); + } + + byte[] outKey = Encoders.ByteArrays.decode(buf); + byte[] outIv = Encoders.ByteArrays.decode(buf); + byte[] inKey = Encoders.ByteArrays.decode(buf); + byte[] inIv = Encoders.ByteArrays.decode(buf); + return new AesConfigMessage(inKey, inIv, outKey, outIv); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java new file mode 100644 index 0000000000000000000000000000000000000000..25d103d0e316f0870d0c14c0e9f7b69f020fa7d9 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; + +import io.netty.buffer.ByteBuf; + +public class ByteArrayReadableChannel implements ReadableByteChannel { + private ByteBuf data; + + public int readableBytes() { + return data.readableBytes(); + } + + public void feedData(ByteBuf buf) { + data = buf; + } + + @Override + public int read(ByteBuffer dst) throws IOException { + int totalRead = 0; + while (data.readableBytes() > 0 && dst.remaining() > 0) { + int bytesToRead = Math.min(data.readableBytes(), dst.remaining()); + dst.put(data.readSlice(bytesToRead).nioBuffer()); + totalRead += bytesToRead; + } + + if (data.readableBytes() == 0) { + data.release(); + } + + return totalRead; + } + + @Override + public void close() throws IOException { + } + + @Override + public boolean isOpen() { + return true; + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 64eaba103cccb865e29b6cf24d8245c3929cd293..d0d072849d38455465b50be453d60f8bbdf9af63 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -18,6 +18,7 @@ package org.apache.spark.network.util; import com.google.common.primitives.Ints; +import org.apache.commons.crypto.cipher.CryptoCipherFactory; /** * A central location that tracks all the settings we expose to users. @@ -175,4 +176,25 @@ public class TransportConf { return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); } + /** + * The trigger for enabling AES encryption. + */ + public boolean aesEncryptionEnabled() { + return conf.getBoolean("spark.authenticate.encryption.aes.enabled", false); + } + + /** + * The implementation class for crypto cipher + */ + public String aesCipherClass() { + return conf.get("spark.authenticate.encryption.aes.cipher.class", null); + } + + /** + * The bytes of AES cipher key which is effective when AES cipher is enabled. Notice that + * the length should be 16, 24 or 32 bytes. + */ + public int aesCipherKeySize() { + return conf.getInt("spark.authenticate.encryption.aes.cipher.keySize", 16); + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 45cc03df435ac5a404f39126ce0dd47280be8974..4e6146cf070d0bdc27545cb17701e838e3aca9a6 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -53,6 +53,7 @@ import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.sasl.aes.AesCipher; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; @@ -149,7 +150,7 @@ public class SparkSaslSuite { .when(rpcHandler) .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); - SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); + SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false, false); try { ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10)); @@ -275,7 +276,7 @@ public class SparkSaslSuite { new Random().nextBytes(data); Files.write(data, file); - ctx = new SaslTestCtx(rpcHandler, true, false); + ctx = new SaslTestCtx(rpcHandler, true, false, false); final CountDownLatch lock = new CountDownLatch(1); @@ -317,7 +318,7 @@ public class SparkSaslSuite { SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), false, false); + ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false); fail("Should have failed to connect without encryption."); } catch (Exception e) { assertTrue(e.getCause() instanceof SaslException); @@ -336,7 +337,7 @@ public class SparkSaslSuite { // able to understand RPCs sent to it and thus close the connection. SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); + ctx = new SaslTestCtx(mock(RpcHandler.class), true, true, false); ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10)); fail("Should have failed to send RPC to server."); @@ -374,6 +375,69 @@ public class SparkSaslSuite { } } + @Test + public void testAesEncryption() throws Exception { + final AtomicReference<ManagedBuffer> response = new AtomicReference<>(); + final File file = File.createTempFile("sasltest", ".txt"); + SaslTestCtx ctx = null; + try { + final TransportConf conf = new TransportConf("rpc", new SystemPropertyConfigProvider()); + final TransportConf spyConf = spy(conf); + doReturn(true).when(spyConf).aesEncryptionEnabled(); + + StreamManager sm = mock(StreamManager.class); + when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() { + @Override + public ManagedBuffer answer(InvocationOnMock invocation) { + return new FileSegmentManagedBuffer(spyConf, file, 0, file.length()); + } + }); + + RpcHandler rpcHandler = mock(RpcHandler.class); + when(rpcHandler.getStreamManager()).thenReturn(sm); + + byte[] data = new byte[256 * 1024 * 1024]; + new Random().nextBytes(data); + Files.write(data, file); + + ctx = new SaslTestCtx(rpcHandler, true, false, true); + + final Object lock = new Object(); + + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + doAnswer(new Answer<Void>() { + @Override + public Void answer(InvocationOnMock invocation) { + response.set((ManagedBuffer) invocation.getArguments()[1]); + response.get().retain(); + synchronized (lock) { + lock.notifyAll(); + } + return null; + } + }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); + + synchronized (lock) { + ctx.client.fetchChunk(0, 0, callback); + lock.wait(10 * 1000); + } + + verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class)); + verify(callback, never()).onFailure(anyInt(), any(Throwable.class)); + + byte[] received = ByteStreams.toByteArray(response.get().createInputStream()); + assertTrue(Arrays.equals(data, received)); + } finally { + file.delete(); + if (ctx != null) { + ctx.close(); + } + if (response.get() != null) { + response.get().release(); + } + } + } + private static class SaslTestCtx { final TransportClient client; @@ -386,18 +450,28 @@ public class SparkSaslSuite { SaslTestCtx( RpcHandler rpcHandler, boolean encrypt, - boolean disableClientEncryption) + boolean disableClientEncryption, + boolean aesEnable) throws Exception { TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + if (aesEnable) { + conf = spy(conf); + doReturn(true).when(conf).aesEncryptionEnabled(); + } + SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); when(keyHolder.getSecretKey(anyString())).thenReturn("secret"); TransportContext ctx = new TransportContext(conf, rpcHandler); - this.checker = new EncryptionCheckerBootstrap(); + String encryptHandlerName = aesEnable ? AesCipher.ENCRYPTION_HANDLER_NAME : + SaslEncryption.ENCRYPTION_HANDLER_NAME; + + this.checker = new EncryptionCheckerBootstrap(encryptHandlerName); + this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder), checker)); @@ -437,13 +511,18 @@ public class SparkSaslSuite { implements TransportServerBootstrap { boolean foundEncryptionHandler; + String encryptHandlerName; + + public EncryptionCheckerBootstrap(String encryptHandlerName) { + this.encryptHandlerName = encryptHandlerName; + } @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { if (!foundEncryptionHandler) { foundEncryptionHandler = - ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null; + ctx.channel().pipeline().get(encryptHandlerName) != null; } ctx.write(msg, promise); } diff --git a/docs/configuration.md b/docs/configuration.md index d0acd944dd6b9af65f8ff547b2c6a85c7b06d350..41c1778ee7fcf719bf682603cdecfb97bcdd6187 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1529,6 +1529,32 @@ Apart from these, the following properties are also available, and may be useful currently supported by the external shuffle service. </td> </tr> +<tr> + <td><code>spark.authenticate.encryption.aes.enabled</code></td> + <td>false</td> + <td> + Enable AES for over-the-wire encryption + </td> +</tr> +<tr> + <td><code>spark.authenticate.encryption.aes.cipher.keySize</code></td> + <td>16</td> + <td> + The bytes of AES cipher key which is effective when AES cipher is enabled. AES + works with 16, 24 and 32 bytes keys. + </td> +</tr> +<tr> + <td><code>spark.authenticate.encryption.aes.cipher.class</code></td> + <td>null</td> + <td> + Specify the underlying implementation class of crypto cipher. Set null here to use default. + In order to use OpenSslCipher users should install openssl. Currently, there are two cipher + classes available in Commons Crypto library: + org.apache.commons.crypto.cipher.OpenSslCipher + org.apache.commons.crypto.cipher.JceCipher + </td> +</tr> <tr> <td><code>spark.core.connection.ack.wait.timeout</code></td> <td>60s</td>