diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java new file mode 100644 index 0000000000000000000000000000000000000000..980525dbf04e0eab27495156838d16566b512464 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; +import java.security.Key; +import javax.crypto.KeyGenerator; +import javax.crypto.Mac; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.sasl.SaslClientBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Bootstraps a {@link TransportClient} by performing authentication using Spark's auth protocol. + * + * This bootstrap falls back to using the SASL bootstrap if the server throws an error during + * authentication, and the configuration allows it. This is used for backwards compatibility + * with external shuffle services that do not support the new protocol. + * + * It also automatically falls back to SASL if the new encryption backend is disabled, so that + * callers only need to install this bootstrap when authentication is enabled. + */ +public class AuthClientBootstrap implements TransportClientBootstrap { + + private static final Logger LOG = LoggerFactory.getLogger(AuthClientBootstrap.class); + + private final TransportConf conf; + private final String appId; + private final String authUser; + private final SecretKeyHolder secretKeyHolder; + + public AuthClientBootstrap( + TransportConf conf, + String appId, + SecretKeyHolder secretKeyHolder) { + this.conf = conf; + // TODO: right now this behaves like the SASL backend, because when executors start up + // they don't necessarily know the app ID. So they send a hardcoded "user" that is defined + // in the SecurityManager, which will also always return the same secret (regardless of the + // user name). All that's needed here is for this "user" to match on both sides, since that's + // required by the protocol. At some point, though, it would be better for the actual app ID + // to be provided here. + this.appId = appId; + this.authUser = secretKeyHolder.getSaslUser(appId); + this.secretKeyHolder = secretKeyHolder; + } + + @Override + public void doBootstrap(TransportClient client, Channel channel) { + if (!conf.encryptionEnabled()) { + LOG.debug("AES encryption disabled, using old auth protocol."); + doSaslAuth(client, channel); + return; + } + + try { + doSparkAuth(client, channel); + } catch (GeneralSecurityException | IOException e) { + throw Throwables.propagate(e); + } catch (RuntimeException e) { + // There isn't a good exception that can be caught here to know whether it's really + // OK to switch back to SASL (because the server doesn't speak the new protocol). So + // try it anyway, and in the worst case things will fail again. + if (conf.saslFallback()) { + LOG.warn("New auth protocol failed, trying SASL.", e); + doSaslAuth(client, channel); + } else { + throw e; + } + } + } + + private void doSparkAuth(TransportClient client, Channel channel) + throws GeneralSecurityException, IOException { + + AuthEngine engine = new AuthEngine(authUser, secretKeyHolder.getSecretKey(authUser), conf); + try { + ClientChallenge challenge = engine.challenge(); + ByteBuf challengeData = Unpooled.buffer(challenge.encodedLength()); + challenge.encode(challengeData); + + ByteBuffer responseData = client.sendRpcSync(challengeData.nioBuffer(), + conf.authRTTimeoutMs()); + ServerResponse response = ServerResponse.decodeMessage(responseData); + + engine.validate(response); + engine.sessionCipher().addToChannel(channel); + } finally { + engine.close(); + } + } + + private void doSaslAuth(TransportClient client, Channel channel) { + SaslClientBootstrap sasl = new SaslClientBootstrap(conf, appId, secretKeyHolder); + sasl.doBootstrap(client, channel); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java new file mode 100644 index 0000000000000000000000000000000000000000..b769ebeba36ccfcb2d8d37d83e46469bd615addd --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.io.Closeable; +import java.io.IOException; +import java.math.BigInteger; +import java.security.GeneralSecurityException; +import java.util.Arrays; +import java.util.Properties; +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.SecretKeyFactory; +import javax.crypto.ShortBufferException; +import javax.crypto.spec.IvParameterSpec; +import javax.crypto.spec.PBEKeySpec; +import javax.crypto.spec.SecretKeySpec; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.primitives.Bytes; +import org.apache.commons.crypto.cipher.CryptoCipher; +import org.apache.commons.crypto.cipher.CryptoCipherFactory; +import org.apache.commons.crypto.random.CryptoRandom; +import org.apache.commons.crypto.random.CryptoRandomFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.util.TransportConf; + +/** + * A helper class for abstracting authentication and key negotiation details. This is used by + * both client and server sides, since the operations are basically the same. + */ +class AuthEngine implements Closeable { + + private static final Logger LOG = LoggerFactory.getLogger(AuthEngine.class); + private static final BigInteger ONE = new BigInteger(new byte[] { 0x1 }); + + private final byte[] appId; + private final char[] secret; + private final TransportConf conf; + private final Properties cryptoConf; + private final CryptoRandom random; + + private byte[] authNonce; + + @VisibleForTesting + byte[] challenge; + + private TransportCipher sessionCipher; + private CryptoCipher encryptor; + private CryptoCipher decryptor; + + AuthEngine(String appId, String secret, TransportConf conf) throws GeneralSecurityException { + this.appId = appId.getBytes(UTF_8); + this.conf = conf; + this.cryptoConf = conf.cryptoConf(); + this.secret = secret.toCharArray(); + this.random = CryptoRandomFactory.getCryptoRandom(cryptoConf); + } + + /** + * Create the client challenge. + * + * @return A challenge to be sent the remote side. + */ + ClientChallenge challenge() throws GeneralSecurityException, IOException { + this.authNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); + SecretKeySpec authKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(), + authNonce, conf.encryptionKeyLength()); + initializeForAuth(conf.cipherTransformation(), authNonce, authKey); + + this.challenge = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); + return new ClientChallenge(new String(appId, UTF_8), + conf.keyFactoryAlgorithm(), + conf.keyFactoryIterations(), + conf.cipherTransformation(), + conf.encryptionKeyLength(), + authNonce, + challenge(appId, authNonce, challenge)); + } + + /** + * Validates the client challenge, and create the encryption backend for the channel from the + * parameters sent by the client. + * + * @param clientChallenge The challenge from the client. + * @return A response to be sent to the client. + */ + ServerResponse respond(ClientChallenge clientChallenge) + throws GeneralSecurityException, IOException { + + SecretKeySpec authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, + clientChallenge.nonce, clientChallenge.keyLength); + initializeForAuth(clientChallenge.cipher, clientChallenge.nonce, authKey); + + byte[] challenge = validateChallenge(clientChallenge.nonce, clientChallenge.challenge); + byte[] response = challenge(appId, clientChallenge.nonce, rawResponse(challenge)); + byte[] sessionNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); + byte[] inputIv = randomBytes(conf.ivLength()); + byte[] outputIv = randomBytes(conf.ivLength()); + + SecretKeySpec sessionKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, + sessionNonce, clientChallenge.keyLength); + this.sessionCipher = new TransportCipher(cryptoConf, clientChallenge.cipher, sessionKey, + inputIv, outputIv); + + // Note the IVs are swapped in the response. + return new ServerResponse(response, encrypt(sessionNonce), encrypt(outputIv), encrypt(inputIv)); + } + + /** + * Validates the server response and initializes the cipher to use for the session. + * + * @param serverResponse The response from the server. + */ + void validate(ServerResponse serverResponse) throws GeneralSecurityException { + byte[] response = validateChallenge(authNonce, serverResponse.response); + + byte[] expected = rawResponse(challenge); + Preconditions.checkArgument(Arrays.equals(expected, response)); + + byte[] nonce = decrypt(serverResponse.nonce); + byte[] inputIv = decrypt(serverResponse.inputIv); + byte[] outputIv = decrypt(serverResponse.outputIv); + + SecretKeySpec sessionKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(), + nonce, conf.encryptionKeyLength()); + this.sessionCipher = new TransportCipher(cryptoConf, conf.cipherTransformation(), sessionKey, + inputIv, outputIv); + } + + TransportCipher sessionCipher() { + Preconditions.checkState(sessionCipher != null); + return sessionCipher; + } + + @Override + public void close() throws IOException { + // Close ciphers (by calling "doFinal()" with dummy data) and the random instance so that + // internal state is cleaned up. Error handling here is just for paranoia, and not meant to + // accurately report the errors when they happen. + RuntimeException error = null; + byte[] dummy = new byte[8]; + try { + doCipherOp(encryptor, dummy, true); + } catch (Exception e) { + error = new RuntimeException(e); + } + try { + doCipherOp(decryptor, dummy, true); + } catch (Exception e) { + error = new RuntimeException(e); + } + random.close(); + + if (error != null) { + throw error; + } + } + + @VisibleForTesting + byte[] challenge(byte[] appId, byte[] nonce, byte[] challenge) throws GeneralSecurityException { + return encrypt(Bytes.concat(appId, nonce, challenge)); + } + + @VisibleForTesting + byte[] rawResponse(byte[] challenge) { + BigInteger orig = new BigInteger(challenge); + BigInteger response = orig.add(ONE); + return response.toByteArray(); + } + + private byte[] decrypt(byte[] in) throws GeneralSecurityException { + return doCipherOp(decryptor, in, false); + } + + private byte[] encrypt(byte[] in) throws GeneralSecurityException { + return doCipherOp(encryptor, in, false); + } + + private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key) + throws GeneralSecurityException { + + // commons-crypto currently only supports ciphers that require an initial vector; so + // create a dummy vector so that we can initialize the ciphers. In the future, if + // different ciphers are supported, this will have to be configurable somehow. + byte[] iv = new byte[conf.ivLength()]; + System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length)); + + encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); + encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv)); + + decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); + decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv)); + } + + /** + * Validates an encrypted challenge as defined in the protocol, and returns the byte array + * that corresponds to the actual challenge data. + */ + private byte[] validateChallenge(byte[] nonce, byte[] encryptedChallenge) + throws GeneralSecurityException { + + byte[] challenge = decrypt(encryptedChallenge); + checkSubArray(appId, challenge, 0); + checkSubArray(nonce, challenge, appId.length); + return Arrays.copyOfRange(challenge, appId.length + nonce.length, challenge.length); + } + + private SecretKeySpec generateKey(String kdf, int iterations, byte[] salt, int keyLength) + throws GeneralSecurityException { + + SecretKeyFactory factory = SecretKeyFactory.getInstance(kdf); + PBEKeySpec spec = new PBEKeySpec(secret, salt, iterations, keyLength); + + long start = System.nanoTime(); + SecretKey key = factory.generateSecret(spec); + long end = System.nanoTime(); + + LOG.debug("Generated key with {} iterations in {} us.", conf.keyFactoryIterations(), + (end - start) / 1000); + + return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm()); + } + + private byte[] doCipherOp(CryptoCipher cipher, byte[] in, boolean isFinal) + throws GeneralSecurityException { + + Preconditions.checkState(cipher != null); + + int scale = 1; + while (true) { + int size = in.length * scale; + byte[] buffer = new byte[size]; + try { + int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0) + : cipher.update(in, 0, in.length, buffer, 0); + if (outSize != buffer.length) { + byte[] output = new byte[outSize]; + System.arraycopy(buffer, 0, output, 0, output.length); + return output; + } else { + return buffer; + } + } catch (ShortBufferException e) { + // Try again with a bigger buffer. + scale *= 2; + } + } + } + + private byte[] randomBytes(int count) { + byte[] bytes = new byte[count]; + random.nextBytes(bytes); + return bytes; + } + + /** Checks that the "test" array is in the data array starting at the given offset. */ + private void checkSubArray(byte[] test, byte[] data, int offset) { + Preconditions.checkArgument(data.length >= test.length + offset); + for (int i = 0; i < test.length; i++) { + Preconditions.checkArgument(test[i] == data[i + offset]); + } + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java new file mode 100644 index 0000000000000000000000000000000000000000..991d8ba95f5ee103aef84094fe7d33480de4253a --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.io.IOException; +import java.nio.ByteBuffer; +import javax.security.sasl.Sasl; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Throwables; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * RPC Handler which performs authentication using Spark's auth protocol before delegating to a + * child RPC handler. If the configuration allows, this handler will delegate messages to a SASL + * RPC handler for further authentication, to support for clients that do not support Spark's + * protocol. + * + * The delegate will only receive messages if the given connection has been successfully + * authenticated. A connection may be authenticated at most once. + */ +class AuthRpcHandler extends RpcHandler { + private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class); + + /** Transport configuration. */ + private final TransportConf conf; + + /** The client channel. */ + private final Channel channel; + + /** + * RpcHandler we will delegate to for authenticated connections. When falling back to SASL + * this will be replaced with the SASL RPC handler. + */ + @VisibleForTesting + RpcHandler delegate; + + /** Class which provides secret keys which are shared by server and client on a per-app basis. */ + private final SecretKeyHolder secretKeyHolder; + + /** Whether auth is done and future calls should be delegated. */ + @VisibleForTesting + boolean doDelegate; + + AuthRpcHandler( + TransportConf conf, + Channel channel, + RpcHandler delegate, + SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.channel = channel; + this.delegate = delegate; + this.secretKeyHolder = secretKeyHolder; + } + + @Override + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + if (doDelegate) { + delegate.receive(client, message, callback); + return; + } + + int position = message.position(); + int limit = message.limit(); + + ClientChallenge challenge; + try { + challenge = ClientChallenge.decodeMessage(message); + LOG.debug("Received new auth challenge for client {}.", channel.remoteAddress()); + } catch (RuntimeException e) { + if (conf.saslFallback()) { + LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.", + channel.remoteAddress()); + delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder); + message.position(position); + message.limit(limit); + delegate.receive(client, message, callback); + doDelegate = true; + } else { + LOG.debug("Unexpected challenge message from client {}, closing channel.", + channel.remoteAddress()); + callback.onFailure(new IllegalArgumentException("Unknown challenge message.")); + channel.close(); + } + return; + } + + // Here we have the client challenge, so perform the new auth protocol and set up the channel. + AuthEngine engine = null; + try { + engine = new AuthEngine(challenge.appId, secretKeyHolder.getSecretKey(challenge.appId), conf); + ServerResponse response = engine.respond(challenge); + ByteBuf responseData = Unpooled.buffer(response.encodedLength()); + response.encode(responseData); + callback.onSuccess(responseData.nioBuffer()); + engine.sessionCipher().addToChannel(channel); + } catch (Exception e) { + // This is a fatal error: authentication has failed. Close the channel explicitly. + LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress()); + callback.onFailure(new IllegalArgumentException("Authentication failed.")); + channel.close(); + return; + } finally { + if (engine != null) { + try { + engine.close(); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + } + + LOG.debug("Authorization successful for client {}.", channel.remoteAddress()); + doDelegate = true; + } + + @Override + public void receive(TransportClient client, ByteBuffer message) { + delegate.receive(client, message); + } + + @Override + public StreamManager getStreamManager() { + return delegate.getStreamManager(); + } + + @Override + public void channelActive(TransportClient client) { + delegate.channelActive(client); + } + + @Override + public void channelInactive(TransportClient client) { + delegate.channelInactive(client); + } + + @Override + public void exceptionCaught(Throwable cause, TransportClient client) { + delegate.exceptionCaught(cause, client); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java new file mode 100644 index 0000000000000000000000000000000000000000..77a2a6af4d1343298aa15fe95105ef5a82abdc1c --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import io.netty.channel.Channel; + +import org.apache.spark.network.sasl.SaslServerBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * A bootstrap which is executed on a TransportServer's client channel once a client connects + * to the server, enabling authentication using Spark's auth protocol (and optionally SASL for + * clients that don't support the new protocol). + * + * It also automatically falls back to SASL if the new encryption backend is disabled, so that + * callers only need to install this bootstrap when authentication is enabled. + */ +public class AuthServerBootstrap implements TransportServerBootstrap { + + private final TransportConf conf; + private final SecretKeyHolder secretKeyHolder; + + public AuthServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.secretKeyHolder = secretKeyHolder; + } + + public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { + if (!conf.encryptionEnabled()) { + TransportServerBootstrap sasl = new SaslServerBootstrap(conf, secretKeyHolder); + return sasl.doBootstrap(channel, rpcHandler); + } + + return new AuthRpcHandler(conf, channel, rpcHandler, secretKeyHolder); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java new file mode 100644 index 0000000000000000000000000000000000000000..3312a5bd81a66e29664d48c50ecbe1d3ba76e59b --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** + * The client challenge message, used to initiate authentication. + * + * @see README.md + */ +public class ClientChallenge implements Encodable { + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xFA; + + public final String appId; + public final String kdf; + public final int iterations; + public final String cipher; + public final int keyLength; + public final byte[] nonce; + public final byte[] challenge; + + public ClientChallenge( + String appId, + String kdf, + int iterations, + String cipher, + int keyLength, + byte[] nonce, + byte[] challenge) { + this.appId = appId; + this.kdf = kdf; + this.iterations = iterations; + this.cipher = cipher; + this.keyLength = keyLength; + this.nonce = nonce; + this.challenge = challenge; + } + + @Override + public int encodedLength() { + return 1 + 4 + 4 + + Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(kdf) + + Encoders.Strings.encodedLength(cipher) + + Encoders.ByteArrays.encodedLength(nonce) + + Encoders.ByteArrays.encodedLength(challenge); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, kdf); + buf.writeInt(iterations); + Encoders.Strings.encode(buf, cipher); + buf.writeInt(keyLength); + Encoders.ByteArrays.encode(buf, nonce); + Encoders.ByteArrays.encode(buf, challenge); + } + + public static ClientChallenge decodeMessage(ByteBuffer buffer) { + ByteBuf buf = Unpooled.wrappedBuffer(buffer); + + if (buf.readByte() != TAG_BYTE) { + throw new IllegalArgumentException("Expected ClientChallenge, received something else."); + } + + return new ClientChallenge( + Encoders.Strings.decode(buf), + Encoders.Strings.decode(buf), + buf.readInt(), + Encoders.Strings.decode(buf), + buf.readInt(), + Encoders.ByteArrays.decode(buf), + Encoders.ByteArrays.decode(buf)); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md new file mode 100644 index 0000000000000000000000000000000000000000..14df703270498877036ae9189803ac3941f10a55 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md @@ -0,0 +1,158 @@ +Spark Auth Protocol and AES Encryption Support +============================================== + +This file describes an auth protocol used by Spark as a more secure alternative to DIGEST-MD5. This +protocol is built on symmetric key encryption, based on the assumption that the two endpoints being +authenticated share a common secret, which is how Spark authentication currently works. The protocol +provides mutual authentication, meaning that after the negotiation both parties know that the remote +side knows the shared secret. The protocol is influenced by the ISO/IEC 9798 protocol, although it's +not an implementation of it. + +This protocol could be replaced with TLS PSK, except no PSK ciphers are available in the currently +released JREs. + +The protocol aims at solving the following shortcomings in Spark's current usage of DIGEST-MD5: + +- MD5 is an aging hash algorithm with known weaknesses, and a more secure alternative is desired. +- DIGEST-MD5 has a pre-defined set of ciphers for which it can generate keys. The only + viable, supported cipher these days is 3DES, and a more modern alternative is desired. +- Encrypting AES session keys with 3DES doesn't solve the issue, since the weakest link + in the negotiation would still be MD5 and 3DES. + +The protocol assumes that the shared secret is generated and distributed in a secure manner. + +The protocol always negotiates encryption keys. If encryption is not desired, the existing +SASL-based authentication, or no authentication at all, can be chosen instead. + +When messages are described below, it's expected that the implementation should support +arbitrary sizes for fields that don't have a fixed size. + +Client Challenge +---------------- + +The auth negotiation is started by the client. The client starts by generating an encryption +key based on the application's shared secret, and a nonce. + + KEY = KDF(SECRET, SALT, KEY_LENGTH) + +Where: +- KDF(): a key derivation function that takes a secret, a salt, a configurable number of + iterations, and a configurable key length. +- SALT: a byte sequence used to salt the key derivation function. +- KEY_LENGTH: length of the encryption key to generate. + + +The client generates a message with the following content: + + CLIENT_CHALLENGE = ( + APP_ID, + KDF, + ITERATIONS, + CIPHER, + KEY_LENGTH, + ANONCE, + ENC(APP_ID || ANONCE || CHALLENGE)) + +Where: + +- APP_ID: the application ID which the server uses to identify the shared secret. +- KDF: the key derivation function described above. +- ITERATIONS: number of iterations to run the KDF when generating keys. +- CIPHER: the cipher used to encrypt data. +- KEY_LENGTH: length of the encryption keys to generate, in bits. +- ANONCE: the nonce used as the salt when generating the auth key. +- ENC(): an encryption function that uses the cipher and the generated key. This function + will also be used in the definition of other messages below. +- CHALLENGE: a byte sequence used as a challenge to the server. +- ||: concatenation operator. + +When strings are used where byte arrays are expected, the UTF-8 representation of the string +is assumed. + +To respond to the challenge, the server should consider the byte array as representing an +arbitrary-length integer, and respond with the value of the integer plus one. + + +Server Response And Challenge +----------------------------- + +Once the client challenge is received, the server will generate the same auth key by +using the same algorithm the client has used. It will then verify the client challenge: +if the APP_ID and ANONCE fields match, the server knows that the client has the shared +secret. The server then creates a response to the client challenge, to prove that it also +has the secret key, and provides parameters to be used when creating the session key. + +The following describes the response from the server: + + SERVER_CHALLENGE = ( + ENC(APP_ID || ANONCE || RESPONSE), + ENC(SNONCE), + ENC(INIV), + ENC(OUTIV)) + +Where: + +- RESPONSE: the server's response to the client challenge. +- SNONCE: a nonce to be used as salt when generating the session key. +- INIV: initialization vector used to initialize the input channel of the client. +- OUTIV: initialization vector used to initialize the output channel of the client. + +At this point the server considers the client to be authenticated, and will try to +decrypt any data further sent by the client using the session key. + + +Default Algorithms +------------------ + +Configuration options are available for the KDF and cipher algorithms to use. + +The default KDF is "PBKDF2WithHmacSHA1". Users should be able to select any algorithm +from those supported by the `javax.crypto.SecretKeyFactory` class, as long as they support +PBEKeySpec when generating keys. The default number of iterations was chosen to take a +reasonable amount of time on modern CPUs. See the documentation in TransportConf for more +details. + +The default cipher algorithm is "AES/CTR/NoPadding". Users should be able to select any +algorithm supported by the commons-crypto library. It should allow the cipher to operate +in stream mode. + +The default key length is 128 (bits). + + +Implementation Details +---------------------- + +The commons-crypto library currently only supports AES ciphers, and requires an initialization +vector (IV). This first version of the protocol does not explicitly include the IV in the client +challenge message. Instead, the IV should be derived from the nonce, including the needed bytes, and +padding the IV with zeroes in case the nonce is not long enough. + +Future versions of the protocol might add support for new ciphers and explicitly include needed +configuration parameters in the messages. + + +Threat Assessment +----------------- + +The protocol is secure against different forms of attack: + +* Eavesdropping: the protocol is built on the assumption that it's computationally infeasible + to calculate the original secret from the encrypted messages. Neither the secret nor any + encryption keys are transmitted on the wire, encrypted or not. + +* Man-in-the-middle: because the protocol performs mutual authentication, both ends need to + know the shared secret to be able to decrypt session data. Even if an attacker is able to insert a + malicious "proxy" between endpoints, the attacker won't be able to read any of the data exchanged + between client and server, nor insert arbitrary commands for the server to execute. + +* Replay attacks: the use of nonces when generating keys prevents an attacker from being able to + just replay messages sniffed from the communication channel. + +An attacker may replay the client challenge and successfully "prove" to a server that it "knows" the +shared secret. But the attacker won't be able to decrypt the server's response, and thus won't be +able to generate a session key, which will make it hard to craft a valid, encrypted message that the +server will be able to understand. This will cause the server to close the connection as soon as the +attacker tries to send any command to the server. The attacker can just hold the channel open for +some time, which will be closed when the server times out the channel. These issues could be +separately mitigated by adding a shorter timeout for the first message after authentication, and +potentially by adding host blacklists if a possible attack is detected from a particular host. diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..affdbf450b1d077d5b526c0667bf30473d70965a --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** + * Server's response to client's challenge. + * + * @see README.md + */ +public class ServerResponse implements Encodable { + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xFB; + + public final byte[] response; + public final byte[] nonce; + public final byte[] inputIv; + public final byte[] outputIv; + + public ServerResponse( + byte[] response, + byte[] nonce, + byte[] inputIv, + byte[] outputIv) { + this.response = response; + this.nonce = nonce; + this.inputIv = inputIv; + this.outputIv = outputIv; + } + + @Override + public int encodedLength() { + return 1 + + Encoders.ByteArrays.encodedLength(response) + + Encoders.ByteArrays.encodedLength(nonce) + + Encoders.ByteArrays.encodedLength(inputIv) + + Encoders.ByteArrays.encodedLength(outputIv); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.ByteArrays.encode(buf, response); + Encoders.ByteArrays.encode(buf, nonce); + Encoders.ByteArrays.encode(buf, inputIv); + Encoders.ByteArrays.encode(buf, outputIv); + } + + public static ServerResponse decodeMessage(ByteBuffer buffer) { + ByteBuf buf = Unpooled.wrappedBuffer(buffer); + + if (buf.readByte() != TAG_BYTE) { + throw new IllegalArgumentException("Expected ServerResponse, received something else."); + } + + return new ServerResponse( + Encoders.ByteArrays.decode(buf), + Encoders.ByteArrays.decode(buf), + Encoders.ByteArrays.decode(buf), + Encoders.ByteArrays.decode(buf)); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java similarity index 64% rename from common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java rename to common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index 340986a63bbf31b6728cc0b37220020fbd1485b5..7376d1ddc4818b96c570199765af9be779a672f1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network.sasl.aes; +package org.apache.spark.network.crypto; import java.io.IOException; import java.nio.ByteBuffer; @@ -25,115 +25,91 @@ import java.util.Properties; import javax.crypto.spec.SecretKeySpec; import javax.crypto.spec.IvParameterSpec; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; -import com.google.common.base.Throwables; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.*; import io.netty.util.AbstractReferenceCounted; -import org.apache.commons.crypto.cipher.CryptoCipherFactory; -import org.apache.commons.crypto.random.CryptoRandom; -import org.apache.commons.crypto.random.CryptoRandomFactory; import org.apache.commons.crypto.stream.CryptoInputStream; import org.apache.commons.crypto.stream.CryptoOutputStream; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import org.apache.spark.network.util.ByteArrayReadableChannel; import org.apache.spark.network.util.ByteArrayWritableChannel; -import org.apache.spark.network.util.TransportConf; /** - * AES cipher for encryption and decryption. + * Cipher for encryption and decryption. */ -public class AesCipher { - private static final Logger logger = LoggerFactory.getLogger(AesCipher.class); - public static final String ENCRYPTION_HANDLER_NAME = "AesEncryption"; - public static final String DECRYPTION_HANDLER_NAME = "AesDecryption"; - public static final int STREAM_BUFFER_SIZE = 1024 * 32; - public static final String TRANSFORM = "AES/CTR/NoPadding"; - - private final SecretKeySpec inKeySpec; - private final IvParameterSpec inIvSpec; - private final SecretKeySpec outKeySpec; - private final IvParameterSpec outIvSpec; - private final Properties properties; - - public AesCipher(AesConfigMessage configMessage, TransportConf conf) throws IOException { - this.properties = conf.cryptoConf(); - this.inKeySpec = new SecretKeySpec(configMessage.inKey, "AES"); - this.inIvSpec = new IvParameterSpec(configMessage.inIv); - this.outKeySpec = new SecretKeySpec(configMessage.outKey, "AES"); - this.outIvSpec = new IvParameterSpec(configMessage.outIv); +public class TransportCipher { + @VisibleForTesting + static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption"; + private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption"; + private static final int STREAM_BUFFER_SIZE = 1024 * 32; + + private final Properties conf; + private final String cipher; + private final SecretKeySpec key; + private final byte[] inIv; + private final byte[] outIv; + + public TransportCipher( + Properties conf, + String cipher, + SecretKeySpec key, + byte[] inIv, + byte[] outIv) { + this.conf = conf; + this.cipher = cipher; + this.key = key; + this.inIv = inIv; + this.outIv = outIv; + } + + public String getCipherTransformation() { + return cipher; + } + + @VisibleForTesting + SecretKeySpec getKey() { + return key; + } + + /** The IV for the input channel (i.e. output channel of the remote side). */ + public byte[] getInputIv() { + return inIv; + } + + /** The IV for the output channel (i.e. input channel of the remote side). */ + public byte[] getOutputIv() { + return outIv; } - /** - * Create AES crypto output stream - * @param ch The underlying channel to write out. - * @return Return output crypto stream for encryption. - * @throws IOException - */ private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { - return new CryptoOutputStream(TRANSFORM, properties, ch, outKeySpec, outIvSpec); + return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv)); } - /** - * Create AES crypto input stream - * @param ch The underlying channel used to read data. - * @return Return input crypto stream for decryption. - * @throws IOException - */ private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { - return new CryptoInputStream(TRANSFORM, properties, ch, inKeySpec, inIvSpec); + return new CryptoInputStream(cipher, conf, ch, key, new IvParameterSpec(inIv)); } /** - * Add handlers to channel + * Add handlers to channel. + * * @param ch the channel for adding handlers * @throws IOException */ public void addToChannel(Channel ch) throws IOException { ch.pipeline() - .addFirst(ENCRYPTION_HANDLER_NAME, new AesEncryptHandler(this)) - .addFirst(DECRYPTION_HANDLER_NAME, new AesDecryptHandler(this)); - } - - /** - * Create the configuration message - * @param conf is the local transport configuration. - * @return Config message for sending. - */ - public static AesConfigMessage createConfigMessage(TransportConf conf) { - int keySize = conf.aesCipherKeySize(); - Properties properties = conf.cryptoConf(); - - try { - int paramLen = CryptoCipherFactory.getCryptoCipher(AesCipher.TRANSFORM, properties) - .getBlockSize(); - byte[] inKey = new byte[keySize]; - byte[] outKey = new byte[keySize]; - byte[] inIv = new byte[paramLen]; - byte[] outIv = new byte[paramLen]; - - CryptoRandom random = CryptoRandomFactory.getCryptoRandom(properties); - random.nextBytes(inKey); - random.nextBytes(outKey); - random.nextBytes(inIv); - random.nextBytes(outIv); - - return new AesConfigMessage(inKey, inIv, outKey, outIv); - } catch (Exception e) { - logger.error("AES config error", e); - throw Throwables.propagate(e); - } + .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this)) + .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this)); } - private static class AesEncryptHandler extends ChannelOutboundHandlerAdapter { + private static class EncryptionHandler extends ChannelOutboundHandlerAdapter { private final ByteArrayWritableChannel byteChannel; private final CryptoOutputStream cos; - AesEncryptHandler(AesCipher cipher) throws IOException { - byteChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE); + EncryptionHandler(TransportCipher cipher) throws IOException { + byteChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); cos = cipher.createOutputStream(byteChannel); } @@ -153,11 +129,11 @@ public class AesCipher { } } - private static class AesDecryptHandler extends ChannelInboundHandlerAdapter { + private static class DecryptionHandler extends ChannelInboundHandlerAdapter { private final CryptoInputStream cis; private final ByteArrayReadableChannel byteChannel; - AesDecryptHandler(AesCipher cipher) throws IOException { + DecryptionHandler(TransportCipher cipher) throws IOException { byteChannel = new ByteArrayReadableChannel(); cis = cipher.createInputStream(byteChannel); } @@ -207,7 +183,7 @@ public class AesCipher { this.buf = isByteBuf ? (ByteBuf) msg : null; this.region = isByteBuf ? null : (FileRegion) msg; this.transferred = 0; - this.byteRawChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE); + this.byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); this.cos = cos; this.byteEncChannel = ch; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index a1bb453657460c9417ebfdacada581e06c15d39f..647813772294ee8dbfedf65278ca8647e37373a7 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -30,8 +30,6 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.sasl.aes.AesCipher; -import org.apache.spark.network.sasl.aes.AesConfigMessage; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; @@ -42,24 +40,14 @@ import org.apache.spark.network.util.TransportConf; public class SaslClientBootstrap implements TransportClientBootstrap { private static final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); - private final boolean encrypt; private final TransportConf conf; private final String appId; private final SecretKeyHolder secretKeyHolder; public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { - this(conf, appId, secretKeyHolder, false); - } - - public SaslClientBootstrap( - TransportConf conf, - String appId, - SecretKeyHolder secretKeyHolder, - boolean encrypt) { this.conf = conf; this.appId = appId; this.secretKeyHolder = secretKeyHolder; - this.encrypt = encrypt; } /** @@ -69,7 +57,7 @@ public class SaslClientBootstrap implements TransportClientBootstrap { */ @Override public void doBootstrap(TransportClient client, Channel channel) { - SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt); + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption()); try { byte[] payload = saslClient.firstToken(); @@ -79,35 +67,19 @@ public class SaslClientBootstrap implements TransportClientBootstrap { msg.encode(buf); buf.writeBytes(msg.body().nioByteBuffer()); - ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs()); + ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs()); payload = saslClient.response(JavaUtils.bufferToArray(response)); } client.setClientId(appId); - if (encrypt) { + if (conf.saslEncryption()) { if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) { throw new RuntimeException( new SaslException("Encryption requests by negotiated non-encrypted connection.")); } - if (conf.aesEncryptionEnabled()) { - // Generate a request config message to send to server. - AesConfigMessage configMessage = AesCipher.createConfigMessage(conf); - ByteBuffer buf = configMessage.encodeMessage(); - - // Encrypted the config message. - byte[] toEncrypt = JavaUtils.bufferToArray(buf); - ByteBuffer encrypted = ByteBuffer.wrap(saslClient.wrap(toEncrypt, 0, toEncrypt.length)); - - client.sendRpcSync(encrypted, conf.saslRTTimeoutMs()); - AesCipher cipher = new AesCipher(configMessage, conf); - logger.info("Enabling AES cipher for client channel {}", client); - cipher.addToChannel(channel); - saslClient.dispose(); - } else { - SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); - } + SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); saslClient = null; logger.debug("Channel {} configured for encryption.", client); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index b2f3ef214b7acafbd7d703cd95e3c19b9c6bb0b2..0231428318addb773969c7910cd24c01a9c103c3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -29,8 +29,6 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.sasl.aes.AesCipher; -import org.apache.spark.network.sasl.aes.AesConfigMessage; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.JavaUtils; @@ -44,7 +42,7 @@ import org.apache.spark.network.util.TransportConf; * Note that the authentication process consists of multiple challenge-response pairs, each of * which are individual RPCs. */ -class SaslRpcHandler extends RpcHandler { +public class SaslRpcHandler extends RpcHandler { private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); /** Transport configuration. */ @@ -63,7 +61,7 @@ class SaslRpcHandler extends RpcHandler { private boolean isComplete; private boolean isAuthenticated; - SaslRpcHandler( + public SaslRpcHandler( TransportConf conf, Channel channel, RpcHandler delegate, @@ -122,37 +120,10 @@ class SaslRpcHandler extends RpcHandler { return; } - if (!conf.aesEncryptionEnabled()) { - logger.debug("Enabling encryption for channel {}", client); - SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); - complete(false); - return; - } - - // Extra negotiation should happen after authentication, so return directly while - // processing authenticate. - if (!isAuthenticated) { - logger.debug("SASL authentication successful for channel {}", client); - isAuthenticated = true; - return; - } - - // Create AES cipher when it is authenticated - try { - byte[] encrypted = JavaUtils.bufferToArray(message); - ByteBuffer decrypted = ByteBuffer.wrap(saslServer.unwrap(encrypted, 0 , encrypted.length)); - - AesConfigMessage configMessage = AesConfigMessage.decodeMessage(decrypted); - AesCipher cipher = new AesCipher(configMessage, conf); - - // Send response back to client to confirm that server accept config. - callback.onSuccess(JavaUtils.stringToBytes(AesCipher.TRANSFORM)); - logger.info("Enabling AES cipher for Server channel {}", client); - cipher.addToChannel(channel); - complete(true); - } catch (IOException ioe) { - throw new RuntimeException(ioe); - } + logger.debug("Enabling encryption for channel {}", client); + SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); + complete(false); + return; } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java deleted file mode 100644 index 3ef6f74a1f89f8701295c6e529e0a90d4950fb2f..0000000000000000000000000000000000000000 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.sasl.aes; - -import java.nio.ByteBuffer; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; - -import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.protocol.Encoders; - -/** - * The AES cipher options for encryption negotiation. - */ -public class AesConfigMessage implements Encodable { - /** Serialization tag used to catch incorrect payloads. */ - private static final byte TAG_BYTE = (byte) 0xEB; - - public byte[] inKey; - public byte[] outKey; - public byte[] inIv; - public byte[] outIv; - - public AesConfigMessage(byte[] inKey, byte[] inIv, byte[] outKey, byte[] outIv) { - if (inKey == null || inIv == null || outKey == null || outIv == null) { - throw new IllegalArgumentException("Cipher Key or IV must not be null!"); - } - - this.inKey = inKey; - this.inIv = inIv; - this.outKey = outKey; - this.outIv = outIv; - } - - @Override - public int encodedLength() { - return 1 + - Encoders.ByteArrays.encodedLength(inKey) + Encoders.ByteArrays.encodedLength(outKey) + - Encoders.ByteArrays.encodedLength(inIv) + Encoders.ByteArrays.encodedLength(outIv); - } - - @Override - public void encode(ByteBuf buf) { - buf.writeByte(TAG_BYTE); - Encoders.ByteArrays.encode(buf, inKey); - Encoders.ByteArrays.encode(buf, inIv); - Encoders.ByteArrays.encode(buf, outKey); - Encoders.ByteArrays.encode(buf, outIv); - } - - /** - * Encode the config message. - * @return ByteBuffer which contains encoded config message. - */ - public ByteBuffer encodeMessage(){ - ByteBuffer buf = ByteBuffer.allocate(encodedLength()); - - ByteBuf wrappedBuf = Unpooled.wrappedBuffer(buf); - wrappedBuf.clear(); - encode(wrappedBuf); - - return buf; - } - - /** - * Decode the config message from buffer - * @param buffer the buffer contain encoded config message - * @return config message - */ - public static AesConfigMessage decodeMessage(ByteBuffer buffer) { - ByteBuf buf = Unpooled.wrappedBuffer(buffer); - - if (buf.readByte() != TAG_BYTE) { - throw new IllegalStateException("Expected AesConfigMessage, received something else" - + " (maybe your client does not have AES enabled?)"); - } - - byte[] outKey = Encoders.ByteArrays.decode(buf); - byte[] outIv = Encoders.ByteArrays.decode(buf); - byte[] inKey = Encoders.ByteArrays.decode(buf); - byte[] inIv = Encoders.ByteArrays.decode(buf); - return new AesConfigMessage(inKey, inIv, outKey, outIv); - } - -} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 6a557fa75d06425d9022b6097a1c6e1daa72be15..c226d8f3bc8fafb296fcc7c3ba108b120b9122bc 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -117,9 +117,10 @@ public class TransportConf { /** Send buffer size (SO_SNDBUF). */ public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); } - /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ - public int saslRTTimeoutMs() { - return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000; + /** Timeout for a single round trip of auth message exchange, in milliseconds. */ + public int authRTTimeoutMs() { + return (int) JavaUtils.timeStringAsSec(conf.get("spark.network.auth.rpcTimeout", + conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s"))) * 1000; } /** @@ -162,40 +163,95 @@ public class TransportConf { } /** - * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled. + * Enables strong encryption. Also enables the new auth protocol, used to negotiate keys. */ - public int maxSaslEncryptedBlockSize() { - return Ints.checkedCast(JavaUtils.byteStringAsBytes( - conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k"))); + public boolean encryptionEnabled() { + return conf.getBoolean("spark.network.crypto.enabled", false); } /** - * Whether the server should enforce encryption on SASL-authenticated connections. + * The cipher transformation to use for encrypting session data. */ - public boolean saslServerAlwaysEncrypt() { - return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); + public String cipherTransformation() { + return conf.get("spark.network.crypto.cipher", "AES/CTR/NoPadding"); + } + + /** + * The key generation algorithm. This should be an algorithm that accepts a "PBEKeySpec" + * as input. The default value (PBKDF2WithHmacSHA1) is available in Java 7. + */ + public String keyFactoryAlgorithm() { + return conf.get("spark.network.crypto.keyFactoryAlgorithm", "PBKDF2WithHmacSHA1"); + } + + /** + * How many iterations to run when generating keys. + * + * See some discussion about this at: http://security.stackexchange.com/q/3959 + * The default value was picked for speed, since it assumes that the secret has good entropy + * (128 bits by default), which is not generally the case with user passwords. + */ + public int keyFactoryIterations() { + return conf.getInt("spark.networy.crypto.keyFactoryIterations", 1024); + } + + /** + * Encryption key length, in bits. + */ + public int encryptionKeyLength() { + return conf.getInt("spark.network.crypto.keyLength", 128); + } + + /** + * Initial vector length, in bytes. + */ + public int ivLength() { + return conf.getInt("spark.network.crypto.ivLength", 16); + } + + /** + * The algorithm for generated secret keys. Nobody should really need to change this, + * but configurable just in case. + */ + public String keyAlgorithm() { + return conf.get("spark.network.crypto.keyAlgorithm", "AES"); + } + + /** + * Whether to fall back to SASL if the new auth protocol fails. Enabled by default for + * backwards compatibility. + */ + public boolean saslFallback() { + return conf.getBoolean("spark.network.crypto.saslFallback", true); } /** - * The trigger for enabling AES encryption. + * Whether to enable SASL-based encryption when authenticating using SASL. */ - public boolean aesEncryptionEnabled() { - return conf.getBoolean("spark.network.aes.enabled", false); + public boolean saslEncryption() { + return conf.getBoolean("spark.authenticate.enableSaslEncryption", false); } /** - * The key size to use when AES cipher is enabled. Notice that the length should be 16, 24 or 32 - * bytes. + * Maximum number of bytes to be encrypted at a time when SASL encryption is used. */ - public int aesCipherKeySize() { - return conf.getInt("spark.network.aes.keySize", 16); + public int maxSaslEncryptedBlockSize() { + return Ints.checkedCast(JavaUtils.byteStringAsBytes( + conf.get("spark.network.sasl.maxEncryptedBlockSize", "64k"))); + } + + /** + * Whether the server should enforce encryption on SASL-authenticated connections. + */ + public boolean saslServerAlwaysEncrypt() { + return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); } /** * The commons-crypto configuration for the module. */ public Properties cryptoConf() { - return CryptoUtils.toCryptoConf("spark.network.aes.config.", conf.getAll()); + return CryptoUtils.toCryptoConf("spark.network.crypto.config.", conf.getAll()); } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..9a186f211312f587d4d704a35acad26ea137aec9 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.util.Arrays; +import java.util.Map; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.collect.ImmutableMap; +import org.junit.BeforeClass; +import org.junit.Test; +import static org.junit.Assert.*; + +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class AuthEngineSuite { + + private static TransportConf conf; + + @BeforeClass + public static void setUp() { + conf = new TransportConf("rpc", MapConfigProvider.EMPTY); + } + + @Test + public void testAuthEngine() throws Exception { + AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf); + + try { + ClientChallenge clientChallenge = client.challenge(); + ServerResponse serverResponse = server.respond(clientChallenge); + client.validate(serverResponse); + + TransportCipher serverCipher = server.sessionCipher(); + TransportCipher clientCipher = client.sessionCipher(); + + assertTrue(Arrays.equals(serverCipher.getInputIv(), clientCipher.getOutputIv())); + assertTrue(Arrays.equals(serverCipher.getOutputIv(), clientCipher.getInputIv())); + assertEquals(serverCipher.getKey(), clientCipher.getKey()); + } finally { + client.close(); + server.close(); + } + } + + @Test + public void testMismatchedSecret() throws Exception { + AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "different_secret", conf); + + ClientChallenge clientChallenge = client.challenge(); + try { + server.respond(clientChallenge); + fail("Should have failed to validate response."); + } catch (IllegalArgumentException e) { + // Expected. + } + } + + @Test(expected = IllegalArgumentException.class) + public void testWrongAppId() throws Exception { + AuthEngine engine = new AuthEngine("appId", "secret", conf); + ClientChallenge challenge = engine.challenge(); + + byte[] badChallenge = engine.challenge(new byte[] { 0x00 }, challenge.nonce, + engine.rawResponse(engine.challenge)); + engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations, + challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge)); + } + + @Test(expected = IllegalArgumentException.class) + public void testWrongNonce() throws Exception { + AuthEngine engine = new AuthEngine("appId", "secret", conf); + ClientChallenge challenge = engine.challenge(); + + byte[] badChallenge = engine.challenge(challenge.appId.getBytes(UTF_8), new byte[] { 0x00 }, + engine.rawResponse(engine.challenge)); + engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations, + challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge)); + } + + @Test(expected = IllegalArgumentException.class) + public void testBadChallenge() throws Exception { + AuthEngine engine = new AuthEngine("appId", "secret", conf); + ClientChallenge challenge = engine.challenge(); + + byte[] badChallenge = new byte[challenge.challenge.length]; + engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations, + challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge)); + } + +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..21609d5aa2a2098441fd7ec432303db42595b1c4 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import io.netty.channel.Channel; +import org.junit.After; +import org.junit.Test; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SaslServerBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class AuthIntegrationSuite { + + private AuthTestCtx ctx; + + @After + public void cleanUp() throws Exception { + if (ctx != null) { + ctx.close(); + } + ctx = null; + } + + @Test + public void testNewAuth() throws Exception { + ctx = new AuthTestCtx(); + ctx.createServer("secret"); + ctx.createClient("secret"); + + ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); + assertEquals("Pong", JavaUtils.bytesToString(reply)); + assertTrue(ctx.authRpcHandler.doDelegate); + assertFalse(ctx.authRpcHandler.delegate instanceof SaslRpcHandler); + } + + @Test + public void testAuthFailure() throws Exception { + ctx = new AuthTestCtx(); + ctx.createServer("server"); + + try { + ctx.createClient("client"); + fail("Should have failed to create client."); + } catch (Exception e) { + assertFalse(ctx.authRpcHandler.doDelegate); + assertFalse(ctx.serverChannel.isActive()); + } + } + + @Test + public void testSaslServerFallback() throws Exception { + ctx = new AuthTestCtx(); + ctx.createServer("secret", true); + ctx.createClient("secret", false); + + ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); + assertEquals("Pong", JavaUtils.bytesToString(reply)); + } + + @Test + public void testSaslClientFallback() throws Exception { + ctx = new AuthTestCtx(); + ctx.createServer("secret", false); + ctx.createClient("secret", true); + + ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); + assertEquals("Pong", JavaUtils.bytesToString(reply)); + } + + @Test + public void testAuthReplay() throws Exception { + // This test covers the case where an attacker replays a challenge message sniffed from the + // network, but doesn't know the actual secret. The server should close the connection as + // soon as a message is sent after authentication is performed. This is emulated by removing + // the client encryption handler after authentication. + ctx = new AuthTestCtx(); + ctx.createServer("secret"); + ctx.createClient("secret"); + + assertNotNull(ctx.client.getChannel().pipeline() + .remove(TransportCipher.ENCRYPTION_HANDLER_NAME)); + + try { + ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); + fail("Should have failed unencrypted RPC."); + } catch (Exception e) { + assertTrue(ctx.authRpcHandler.doDelegate); + } + } + + private class AuthTestCtx { + + private final String appId = "testAppId"; + private final TransportConf conf; + private final TransportContext ctx; + + TransportClient client; + TransportServer server; + volatile Channel serverChannel; + volatile AuthRpcHandler authRpcHandler; + + AuthTestCtx() throws Exception { + Map<String, String> testConf = ImmutableMap.of("spark.network.crypto.enabled", "true"); + this.conf = new TransportConf("rpc", new MapConfigProvider(testConf)); + + RpcHandler rpcHandler = new RpcHandler() { + @Override + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + assertEquals("Ping", JavaUtils.bytesToString(message)); + callback.onSuccess(JavaUtils.stringToBytes("Pong")); + } + + @Override + public StreamManager getStreamManager() { + return null; + } + }; + + this.ctx = new TransportContext(conf, rpcHandler); + } + + void createServer(String secret) throws Exception { + createServer(secret, true); + } + + void createServer(String secret, boolean enableAes) throws Exception { + TransportServerBootstrap introspector = new TransportServerBootstrap() { + @Override + public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { + AuthTestCtx.this.serverChannel = channel; + if (rpcHandler instanceof AuthRpcHandler) { + AuthTestCtx.this.authRpcHandler = (AuthRpcHandler) rpcHandler; + } + return rpcHandler; + } + }; + SecretKeyHolder keyHolder = createKeyHolder(secret); + TransportServerBootstrap auth = enableAes ? new AuthServerBootstrap(conf, keyHolder) + : new SaslServerBootstrap(conf, keyHolder); + this.server = ctx.createServer(Lists.newArrayList(auth, introspector)); + } + + void createClient(String secret) throws Exception { + createClient(secret, true); + } + + void createClient(String secret, boolean enableAes) throws Exception { + TransportConf clientConf = enableAes ? conf + : new TransportConf("rpc", MapConfigProvider.EMPTY); + List<TransportClientBootstrap> bootstraps = Lists.<TransportClientBootstrap>newArrayList( + new AuthClientBootstrap(clientConf, appId, createKeyHolder(secret))); + this.client = ctx.createClientFactory(bootstraps) + .createClient(TestUtils.getLocalHost(), server.getPort()); + } + + void close() { + if (client != null) { + client.close(); + } + if (server != null) { + server.close(); + } + } + + private SecretKeyHolder createKeyHolder(String secret) { + SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); + when(keyHolder.getSaslUser(anyString())).thenReturn(appId); + when(keyHolder.getSecretKey(anyString())).thenReturn(secret); + return keyHolder; + } + + } + +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..a90ff247da4fc7273db3577e83e5776cd1dee157 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.nio.ByteBuffer; +import java.util.Arrays; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.Test; +import static org.junit.Assert.*; + +import org.apache.spark.network.protocol.Encodable; + +public class AuthMessagesSuite { + + private static int COUNTER = 0; + + private static String string() { + return String.valueOf(COUNTER++); + } + + private static byte[] byteArray() { + byte[] bytes = new byte[COUNTER++]; + for (int i = 0; i < bytes.length; i++) { + bytes[i] = (byte) COUNTER; + } return bytes; + } + + private static int integer() { + return COUNTER++; + } + + @Test + public void testClientChallenge() { + ClientChallenge msg = new ClientChallenge(string(), string(), integer(), string(), integer(), + byteArray(), byteArray()); + ClientChallenge decoded = ClientChallenge.decodeMessage(encode(msg)); + + assertEquals(msg.appId, decoded.appId); + assertEquals(msg.kdf, decoded.kdf); + assertEquals(msg.iterations, decoded.iterations); + assertEquals(msg.cipher, decoded.cipher); + assertEquals(msg.keyLength, decoded.keyLength); + assertTrue(Arrays.equals(msg.nonce, decoded.nonce)); + assertTrue(Arrays.equals(msg.challenge, decoded.challenge)); + } + + @Test + public void testServerResponse() { + ServerResponse msg = new ServerResponse(byteArray(), byteArray(), byteArray(), byteArray()); + ServerResponse decoded = ServerResponse.decodeMessage(encode(msg)); + assertTrue(Arrays.equals(msg.response, decoded.response)); + assertTrue(Arrays.equals(msg.nonce, decoded.nonce)); + assertTrue(Arrays.equals(msg.inputIv, decoded.inputIv)); + assertTrue(Arrays.equals(msg.outputIv, decoded.outputIv)); + } + + private ByteBuffer encode(Encodable msg) { + ByteBuf buf = Unpooled.buffer(); + msg.encode(buf); + return buf.nioBuffer(); + } + +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index e27301f49e34b88495caba9607ed591348224936..87129b900bf0b6c8517d46a829702d528f590043 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -56,7 +56,6 @@ import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.sasl.aes.AesCipher; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; @@ -153,7 +152,7 @@ public class SparkSaslSuite { .when(rpcHandler) .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); - SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false, false); + SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); try { ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10)); @@ -279,7 +278,7 @@ public class SparkSaslSuite { new Random().nextBytes(data); Files.write(data, file); - ctx = new SaslTestCtx(rpcHandler, true, false, false, testConf); + ctx = new SaslTestCtx(rpcHandler, true, false, testConf); final CountDownLatch lock = new CountDownLatch(1); @@ -317,7 +316,7 @@ public class SparkSaslSuite { public void testServerAlwaysEncrypt() throws Exception { SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false, + ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true")); fail("Should have failed to connect without encryption."); } catch (Exception e) { @@ -336,7 +335,7 @@ public class SparkSaslSuite { // able to understand RPCs sent to it and thus close the connection. SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), true, true, false); + ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10)); fail("Should have failed to send RPC to server."); @@ -374,69 +373,6 @@ public class SparkSaslSuite { } } - @Test - public void testAesEncryption() throws Exception { - final AtomicReference<ManagedBuffer> response = new AtomicReference<>(); - final File file = File.createTempFile("sasltest", ".txt"); - SaslTestCtx ctx = null; - try { - final TransportConf conf = new TransportConf("rpc", MapConfigProvider.EMPTY); - final TransportConf spyConf = spy(conf); - doReturn(true).when(spyConf).aesEncryptionEnabled(); - - StreamManager sm = mock(StreamManager.class); - when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer<ManagedBuffer>() { - @Override - public ManagedBuffer answer(InvocationOnMock invocation) { - return new FileSegmentManagedBuffer(spyConf, file, 0, file.length()); - } - }); - - RpcHandler rpcHandler = mock(RpcHandler.class); - when(rpcHandler.getStreamManager()).thenReturn(sm); - - byte[] data = new byte[256 * 1024 * 1024]; - new Random().nextBytes(data); - Files.write(data, file); - - ctx = new SaslTestCtx(rpcHandler, true, false, true); - - final Object lock = new Object(); - - ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - doAnswer(new Answer<Void>() { - @Override - public Void answer(InvocationOnMock invocation) { - response.set((ManagedBuffer) invocation.getArguments()[1]); - response.get().retain(); - synchronized (lock) { - lock.notifyAll(); - } - return null; - } - }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); - - synchronized (lock) { - ctx.client.fetchChunk(0, 0, callback); - lock.wait(10 * 1000); - } - - verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class)); - verify(callback, never()).onFailure(anyInt(), any(Throwable.class)); - - byte[] received = ByteStreams.toByteArray(response.get().createInputStream()); - assertTrue(Arrays.equals(data, received)); - } finally { - file.delete(); - if (ctx != null) { - ctx.close(); - } - if (response.get() != null) { - response.get().release(); - } - } - } - private static class SaslTestCtx { final TransportClient client; @@ -449,46 +385,39 @@ public class SparkSaslSuite { SaslTestCtx( RpcHandler rpcHandler, boolean encrypt, - boolean disableClientEncryption, - boolean aesEnable) + boolean disableClientEncryption) throws Exception { - this(rpcHandler, encrypt, disableClientEncryption, aesEnable, - Collections.<String, String>emptyMap()); + this(rpcHandler, encrypt, disableClientEncryption, Collections.<String, String>emptyMap()); } SaslTestCtx( RpcHandler rpcHandler, boolean encrypt, boolean disableClientEncryption, - boolean aesEnable, - Map<String, String> testConf) + Map<String, String> extraConf) throws Exception { + Map<String, String> testConf = ImmutableMap.<String, String>builder() + .putAll(extraConf) + .put("spark.authenticate.enableSaslEncryption", String.valueOf(encrypt)) + .build(); TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf)); - if (aesEnable) { - conf = spy(conf); - doReturn(true).when(conf).aesEncryptionEnabled(); - } - SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); when(keyHolder.getSecretKey(anyString())).thenReturn("secret"); TransportContext ctx = new TransportContext(conf, rpcHandler); - String encryptHandlerName = aesEnable ? AesCipher.ENCRYPTION_HANDLER_NAME : - SaslEncryption.ENCRYPTION_HANDLER_NAME; - - this.checker = new EncryptionCheckerBootstrap(encryptHandlerName); + this.checker = new EncryptionCheckerBootstrap(SaslEncryption.ENCRYPTION_HANDLER_NAME); this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder), checker)); try { List<TransportClientBootstrap> clientBootstraps = Lists.newArrayList(); - clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt)); + clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder)); if (disableClientEncryption) { clientBootstraps.add(new EncryptionDisablerBootstrap()); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 772fb88325b35bda3d660bf189cc6bd6b9ec034c..616505d9796d05b172c8bdb6150c0abc843b358c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; -import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -30,7 +29,7 @@ import org.apache.spark.network.TransportContext; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.sasl.SaslClientBootstrap; +import org.apache.spark.network.crypto.AuthClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; @@ -47,8 +46,7 @@ public class ExternalShuffleClient extends ShuffleClient { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); private final TransportConf conf; - private final boolean saslEnabled; - private final boolean saslEncryptionEnabled; + private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; protected TransportClientFactory clientFactory; @@ -61,15 +59,10 @@ public class ExternalShuffleClient extends ShuffleClient { public ExternalShuffleClient( TransportConf conf, SecretKeyHolder secretKeyHolder, - boolean saslEnabled, - boolean saslEncryptionEnabled) { - Preconditions.checkArgument( - !saslEncryptionEnabled || saslEnabled, - "SASL encryption can only be enabled if SASL is also enabled."); + boolean authEnabled) { this.conf = conf; this.secretKeyHolder = secretKeyHolder; - this.saslEnabled = saslEnabled; - this.saslEncryptionEnabled = saslEncryptionEnabled; + this.authEnabled = authEnabled; } protected void checkInit() { @@ -81,8 +74,8 @@ public class ExternalShuffleClient extends ShuffleClient { this.appId = appId; TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); List<TransportClientBootstrap> bootstraps = Lists.newArrayList(); - if (saslEnabled) { - bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled)); + if (authEnabled) { + bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); } clientFactory = context.createClientFactory(bootstraps); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 42cedd9943150a95c25acb1b587ca9bea7ca7c83..ab49b1c1d789694338c2ee3b02a9fe389e7dc146 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -60,9 +60,8 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient { public MesosExternalShuffleClient( TransportConf conf, SecretKeyHolder secretKeyHolder, - boolean saslEnabled, - boolean saslEncryptionEnabled) { - super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled); + boolean authEnabled) { + super(conf, secretKeyHolder, authEnabled); } public void registerDriverWithShuffleService( diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 8dd97b29eb36807b4ee89fc4564acd487d5eccc0..9248ef3c467dfd38704a23ce797fd637b3e962d7 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -133,7 +133,7 @@ public class ExternalShuffleIntegrationSuite { final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false, false); + ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false); client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @@ -243,7 +243,7 @@ public class ExternalShuffleIntegrationSuite { private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) throws IOException { - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index aed25a161e17ea9c8155bae4a4e53fdb68c55f59..4ae75a1b1762a979333d713feb1faa3f7358bdf4 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -20,6 +20,7 @@ package org.apache.spark.network.shuffle; import java.io.IOException; import java.util.Arrays; +import com.google.common.collect.ImmutableMap; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -88,8 +89,14 @@ public class ExternalShuffleSecuritySuite { /** Creates an ExternalShuffleClient and attempts to register with the server. */ private void validate(String appId, String secretKey, boolean encrypt) throws IOException { + TransportConf testConf = conf; + if (encrypt) { + testConf = new TransportConf("shuffle", new MapConfigProvider( + ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true"))); + } + ExternalShuffleClient client = - new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt); + new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true); client.init(appId); // Registration either succeeds or throws an exception. client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index ea726e3c8240ea141ad41f1e53a10d1a312c30c7..c7620d0fe128806b647b00e3abe6c42385f8bd07 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -45,7 +45,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; -import org.apache.spark.network.sasl.SaslServerBootstrap; +import org.apache.spark.network.crypto.AuthServerBootstrap; import org.apache.spark.network.sasl.ShuffleSecretManager; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; @@ -172,7 +172,7 @@ public class YarnShuffleService extends AuxiliaryService { boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); if (authEnabled) { createSecretManager(); - bootstraps.add(new SaslServerBootstrap(transportConf, secretManager)); + bootstraps.add(new AuthServerBootstrap(transportConf, secretManager)); } int port = conf.getInt( diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 9bdc5096b6afd0a9b5c6bc119902b155e0fe2cee..cde768281f1197870308e1273a2be3d8cca74157 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.io.Text import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.network.sasl.SecretKeyHolder import org.apache.spark.util.Utils @@ -191,7 +192,7 @@ private[spark] class SecurityManager( // allow all users/groups to have view/modify permissions private val WILDCARD_ACL = "*" - private val authOn = sparkConf.getBoolean(SecurityManager.SPARK_AUTH_CONF, false) + private val authOn = sparkConf.get(NETWORK_AUTH_ENABLED) // keep spark.ui.acls.enable for backwards compatibility with 1.0 private var aclsOn = sparkConf.getBoolean("spark.acls.enable", sparkConf.getBoolean("spark.ui.acls.enable", false)) @@ -516,11 +517,11 @@ private[spark] class SecurityManager( def isAuthenticationEnabled(): Boolean = authOn /** - * Checks whether SASL encryption should be enabled. - * @return Whether to enable SASL encryption when connecting to services that support it. + * Checks whether network encryption should be enabled. + * @return Whether to enable encryption when connecting to services that support it. */ - def isSaslEncryptionEnabled(): Boolean = { - sparkConf.getBoolean("spark.authenticate.enableSaslEncryption", false) + def isEncryptionEnabled(): Boolean = { + sparkConf.get(NETWORK_ENCRYPTION_ENABLED) || sparkConf.get(SASL_ENCRYPTION_ENABLED) } /** diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 601d24191eec87bae09498977f92a2810992705e..308a1ed5fa963e71ad9aefe7a24b8e4bd0cb873d 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -607,6 +607,10 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria "\"client\".") } } + + val encryptionEnabled = get(NETWORK_ENCRYPTION_ENABLED) || get(SASL_ENCRYPTION_ENABLED) + require(!encryptionEnabled || get(NETWORK_AUTH_ENABLED), + s"${NETWORK_AUTH_ENABLED.key} must be enabled when enabling encryption.") } /** @@ -726,6 +730,7 @@ private[spark] object SparkConf extends Logging { (name.startsWith("spark.auth") && name != SecurityManager.SPARK_AUTH_SECRET_CONF) || name.startsWith("spark.ssl") || name.startsWith("spark.rpc") || + name.startsWith("spark.network") || isSparkPortConf(name) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 1296386ac9bd3ad7f4062c292088fd991bfcb852..539dbb55eeff0922e2f34ce0b17048f95c8a0aa8 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -235,7 +235,7 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf, ioEncryptionKey) ioEncryptionKey.foreach { _ => - if (!securityManager.isSaslEncryptionEnabled()) { + if (!securityManager.isEncryptionEnabled()) { logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " + "wire.") } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 13eadbe44f612fd18d36bb75c19f91f5c0eda67b..8d491ddf6e09216d598c03e7c1807fcbf9120555 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -25,8 +25,8 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.TransportContext +import org.apache.spark.network.crypto.AuthServerBootstrap import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.sasl.SaslServerBootstrap import org.apache.spark.network.server.{TransportServer, TransportServerBootstrap} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.util.TransportConf @@ -47,7 +47,6 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) - private val useSasl: Boolean = securityManager.isAuthenticationEnabled() private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0) @@ -74,10 +73,11 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana /** Start the external shuffle service */ def start() { require(server == null, "Shuffle server already started") - logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + val authEnabled = securityManager.isAuthenticationEnabled() + logInfo(s"Starting shuffle service on port $port (auth enabled = $authEnabled)") val bootstraps: Seq[TransportServerBootstrap] = - if (useSasl) { - Seq(new SaslServerBootstrap(transportConf, securityManager)) + if (authEnabled) { + Seq(new AuthServerBootstrap(transportConf, securityManager)) } else { Nil } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index aba429bcdca60e8fc0cae1107b9e1c94473c3217..536f493b417228d30fb8f794218a9c24fb2c44d3 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -243,4 +243,20 @@ package object config { "and event logs.") .stringConf .createWithDefault("(?i)secret|password") + + private[spark] val NETWORK_AUTH_ENABLED = + ConfigBuilder("spark.authenticate") + .booleanConf + .createWithDefault(false) + + private[spark] val SASL_ENCRYPTION_ENABLED = + ConfigBuilder("spark.authenticate.enableSaslEncryption") + .booleanConf + .createWithDefault(false) + + private[spark] val NETWORK_ENCRYPTION_ENABLED = + ConfigBuilder("spark.network.crypto.enabled") + .booleanConf + .createWithDefault(false) + } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 3d4ea3cccc934e18b127d5d0c1c5f5edd808a776..b75e91b660969f0da610020d79e5d47c32f3beb1 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -27,7 +27,7 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} -import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} +import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock @@ -63,9 +63,8 @@ private[spark] class NettyBlockTransferService( var serverBootstrap: Option[TransportServerBootstrap] = None var clientBootstrap: Option[TransportClientBootstrap] = None if (authEnabled) { - serverBootstrap = Some(new SaslServerBootstrap(transportConf, securityManager)) - clientBootstrap = Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager, - securityManager.isSaslEncryptionEnabled())) + serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager)) + clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager)) } transportContext = new TransportContext(transportConf, rpcHandler) clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index e56943da1303a3ad17e88d6560b8263abbab8b52..1e448b2f1a5c69656e52652ff9fd31dbbd72d12b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -33,8 +33,8 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ +import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.rpc._ import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance} @@ -60,8 +60,8 @@ private[netty] class NettyRpcEnv( private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = { if (securityManager.isAuthenticationEnabled()) { - java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, - securityManager.isSaslEncryptionEnabled())) + java.util.Arrays.asList(new AuthClientBootstrap(transportConf, + securityManager.getSaslUser(), securityManager)) } else { java.util.Collections.emptyList[TransportClientBootstrap] } @@ -111,7 +111,7 @@ private[netty] class NettyRpcEnv( def startServer(bindAddress: String, port: Int): Unit = { val bootstraps: java.util.List[TransportServerBootstrap] = if (securityManager.isAuthenticationEnabled()) { - java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager)) + java.util.Arrays.asList(new AuthServerBootstrap(transportConf, securityManager)) } else { java.util.Collections.emptyList() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 04521c9159eacf97931367e20d1616959bf54e11..c40186756f2d5a40515dc6f6b7fcdef0ebe05b09 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -125,8 +125,7 @@ private[spark] class BlockManager( // standard BlockTransferService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) - new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), - securityManager.isSaslEncryptionEnabled()) + new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) } else { blockTransferService } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 83906cff123bfc277351c9b4abfb4b4ac3c56e1b..0897891ee17584565a3addbbe455c95f80edd187 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -303,6 +303,25 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } + test("encryption requires authentication") { + val conf = new SparkConf() + conf.validateSettings() + + conf.set(NETWORK_ENCRYPTION_ENABLED, true) + intercept[IllegalArgumentException] { + conf.validateSettings() + } + + conf.set(NETWORK_ENCRYPTION_ENABLED, false) + conf.set(SASL_ENCRYPTION_ENABLED, true) + intercept[IllegalArgumentException] { + conf.validateSettings() + } + + conf.set(NETWORK_AUTH_ENABLED, true) + conf.validateSettings() + } + } class Class1 {} diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 022fe91edade9e6a6a8ab6fe2db722b5d92598cc..fe8955840d72f2c4192bd3234c0409a7ced5b231 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -94,6 +94,20 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } } + test("security with aes encryption") { + val conf = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + .set("spark.network.crypto.enabled", "true") + .set("spark.network.crypto.saslFallback", "false") + testConnection(conf, conf) match { + case Success(_) => // expected + case Failure(t) => fail(t) + } + } + + /** * Creates two servers with different configurations and sees if they can talk. * Returns Success() if they can transfer a block, and Failure() if the block transfer was failed diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index acdf21df9a16157b3f1be95945fb3a94a6ffdead..b4037d7a9c6e8adbbdfbaa2908c5346e0489be73 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -637,11 +637,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(anotherEnv.address.port != env.address.port) } - test("send with authentication") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - + private def testSend(conf: SparkConf): Unit = { val localEnv = createRpcEnv(conf, "authentication-local", 0) val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) @@ -667,11 +663,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } - test("ask with authentication") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - + private def testAsk(conf: SparkConf): Unit = { val localEnv = createRpcEnv(conf, "authentication-local", 0) val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) @@ -695,6 +687,48 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("send with authentication") { + testSend(new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good")) + } + + test("send with SASL encryption") { + testSend(new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.authenticate.enableSaslEncryption", "true")) + } + + test("send with AES encryption") { + testSend(new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.network.crypto.enabled", "true") + .set("spark.network.crypto.saslFallback", "false")) + } + + test("ask with authentication") { + testAsk(new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good")) + } + + test("ask with SASL encryption") { + testAsk(new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.authenticate.enableSaslEncryption", "true")) + } + + test("ask with AES encryption") { + testAsk(new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.network.crypto.enabled", "true") + .set("spark.network.crypto.saslFallback", "false")) + } + test("construct RpcTimeout with conf property") { val conf = new SparkConf diff --git a/docs/configuration.md b/docs/configuration.md index b7f10e69f38e419a6f0f4d1c854f90afb3706ed4..7c040330db637115c97c103759eedb5d21009243 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1639,40 +1639,40 @@ Apart from these, the following properties are also available, and may be useful </td> </tr> <tr> - <td><code>spark.authenticate.enableSaslEncryption</code></td> + <td><code>spark.network.crypto.enabled</code></td> <td>false</td> <td> - Enable encrypted communication when authentication is - enabled. This is supported by the block transfer service and the - RPC endpoints. + Enable encryption using the commons-crypto library for RPC and block transfer service. + Requires <code>spark.authenticate</code> to be enabled. </td> </tr> <tr> - <td><code>spark.network.sasl.serverAlwaysEncrypt</code></td> - <td>false</td> + <td><code>spark.network.crypto.keyLength</code></td> + <td>128</td> <td> - Disable unencrypted connections for services that support SASL authentication. This is - currently supported by the external shuffle service. + The length in bits of the encryption key to generate. Valid values are 128, 192 and 256. </td> </tr> <tr> - <td><code>spark.network.aes.enabled</code></td> - <td>false</td> + <td><code>spark.network.crypto.keyFactoryAlgorithm</code></td> + <td>PBKDF2WithHmacSHA1</td> <td> - Enable AES for over-the-wire encryption. This is supported for RPC and the block transfer service. - This option has precedence over SASL-based encryption if both are enabled. + The key factory algorithm to use when generating encryption keys. Should be one of the + algorithms supported by the javax.crypto.SecretKeyFactory class in the JRE being used. </td> </tr> <tr> - <td><code>spark.network.aes.keySize</code></td> - <td>16</td> + <td><code>spark.network.crypto.saslFallback</code></td> + <td>true</td> <td> - The bytes of AES cipher key which is effective when AES cipher is enabled. AES - works with 16, 24 and 32 bytes keys. + Whether to fall back to SASL authentication if authentication fails using Spark's internal + mechanism. This is useful when the application is connecting to old shuffle services that + do not support the internal Spark authentication protocol. On the server side, this can be + used to block older clients from authenticating against a new shuffle service. </td> </tr> <tr> - <td><code>spark.network.aes.config.*</code></td> + <td><code>spark.network.crypto.config.*</code></td> <td>None</td> <td> Configuration values for the commons-crypto library, such as which cipher implementations to @@ -1680,6 +1680,22 @@ Apart from these, the following properties are also available, and may be useful "commons.crypto" prefix. </td> </tr> +<tr> + <td><code>spark.authenticate.enableSaslEncryption</code></td> + <td>false</td> + <td> + Enable encrypted communication when authentication is + enabled. This is supported by the block transfer service and the + RPC endpoints. + </td> +</tr> +<tr> + <td><code>spark.network.sasl.serverAlwaysEncrypt</code></td> + <td>false</td> + <td> + Disable unencrypted connections for services that support SASL authentication. + </td> +</tr> <tr> <td><code>spark.core.connection.ack.wait.timeout</code></td> <td><code>spark.network.timeout</code></td> diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 3258b09c06278da8f7902c3bb26b06d0a6da903d..f555072c3842ab06975abdc1c3cfa23a75c60cc1 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -136,8 +136,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( new MesosExternalShuffleClient( SparkTransportConf.fromSparkConf(conf, "shuffle"), securityManager, - securityManager.isAuthenticationEnabled(), - securityManager.isSaslEncryptionEnabled()) + securityManager.isAuthenticationEnabled()) } var nextMesosTaskId = 0