diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java index 799f4540aa934f7a1d6a7f6fc7551faaf2da97b7..3c263783a6104b51da44b5651e32acd2a772c7a1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java @@ -50,7 +50,6 @@ public class AuthClientBootstrap implements TransportClientBootstrap { private final TransportConf conf; private final String appId; - private final String authUser; private final SecretKeyHolder secretKeyHolder; public AuthClientBootstrap( @@ -65,7 +64,6 @@ public class AuthClientBootstrap implements TransportClientBootstrap { // required by the protocol. At some point, though, it would be better for the actual app ID // to be provided here. this.appId = appId; - this.authUser = secretKeyHolder.getSaslUser(appId); this.secretKeyHolder = secretKeyHolder; } @@ -97,8 +95,8 @@ public class AuthClientBootstrap implements TransportClientBootstrap { private void doSparkAuth(TransportClient client, Channel channel) throws GeneralSecurityException, IOException { - String secretKey = secretKeyHolder.getSecretKey(authUser); - try (AuthEngine engine = new AuthEngine(authUser, secretKey, conf)) { + String secretKey = secretKeyHolder.getSecretKey(appId); + try (AuthEngine engine = new AuthEngine(appId, secretKey, conf)) { ClientChallenge challenge = engine.challenge(); ByteBuf challengeData = Unpooled.buffer(challenge.encodedLength()); challenge.encode(challengeData); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 0a5c029940005489899e71f1f0549ae6fb999804..8a6e3858081bfb1610865176b654022531b8361d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -20,6 +20,7 @@ package org.apache.spark.network.crypto; import java.nio.ByteBuffer; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -113,7 +114,11 @@ class AuthRpcHandler extends RpcHandler { // Here we have the client challenge, so perform the new auth protocol and set up the channel. AuthEngine engine = null; try { - engine = new AuthEngine(challenge.appId, secretKeyHolder.getSecretKey(challenge.appId), conf); + String secret = secretKeyHolder.getSecretKey(challenge.appId); + Preconditions.checkState(secret != null, + "Trying to authenticate non-registered app %s.", challenge.appId); + LOG.debug("Authenticating challenge for app {}.", challenge.appId); + engine = new AuthEngine(challenge.appId, secret, conf); ServerResponse response = engine.respond(challenge); ByteBuf responseData = Unpooled.buffer(response.encodedLength()); response.encode(responseData); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java index 426a604f4f1575570f01fad5427ec15a4cce17a8..d2d008f8a3d355040d9d22b143c03ad1f74bef14 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -47,7 +47,7 @@ public class ShuffleSecretManager implements SecretKeyHolder { * fetching shuffle files written by other executors in this application. */ public void registerApp(String appId, String shuffleSecret) { - if (!shuffleSecretMap.contains(appId)) { + if (!shuffleSecretMap.containsKey(appId)) { shuffleSecretMap.put(appId, shuffleSecret); logger.info("Registered shuffle secret for application {}", appId); } else { @@ -67,7 +67,7 @@ public class ShuffleSecretManager implements SecretKeyHolder { * This is called when the application terminates. */ public void unregisterApp(String appId) { - if (shuffleSecretMap.contains(appId)) { + if (shuffleSecretMap.containsKey(appId)) { shuffleSecretMap.remove(appId); logger.info("Unregistered shuffle secret for application {}", appId); } else { diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index fd50e3a4bfb9bb34c1dfbc15f090aeeebb17a58a..cd67eb28573e836141665041ed94766e9dd44e14 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -243,7 +243,6 @@ public class YarnShuffleService extends AuxiliaryService { String appId = context.getApplicationId().toString(); try { ByteBuffer shuffleSecret = context.getApplicationDataForService(); - logger.info("Initializing application {}", appId); if (isAuthenticationEnabled()) { AppId fullId = new AppId(appId); if (db != null) { @@ -262,7 +261,6 @@ public class YarnShuffleService extends AuxiliaryService { public void stopApplication(ApplicationTerminationContext context) { String appId = context.getApplicationId().toString(); try { - logger.info("Stopping application {}", appId); if (isAuthenticationEnabled()) { AppId fullId = new AppId(appId); if (db != null) { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index 950ebd9a2d4d9b7bc0c7c42e05fac90d4ba7d08a..75427b4ad6cb4e95109848d1eea86fee199631c7 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -26,7 +26,9 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers import org.apache.spark._ +import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.network.shuffle.ShuffleTestAccessor import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor} import org.apache.spark.tags.ExtendedYarnTest @@ -46,28 +48,58 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite { yarnConfig } + protected def extraSparkConf(): Map[String, String] = { + val shuffleServicePort = YarnTestAccessor.getShuffleServicePort + val shuffleService = YarnTestAccessor.getShuffleServiceInstance + logInfo("Shuffle service port = " + shuffleServicePort) + + Map( + "spark.shuffle.service.enabled" -> "true", + "spark.shuffle.service.port" -> shuffleServicePort.toString, + MAX_EXECUTOR_FAILURES.key -> "1" + ) + } + test("external shuffle service") { val shuffleServicePort = YarnTestAccessor.getShuffleServicePort val shuffleService = YarnTestAccessor.getShuffleServiceInstance val registeredExecFile = YarnTestAccessor.getRegisteredExecutorFile(shuffleService) - logInfo("Shuffle service port = " + shuffleServicePort) val result = File.createTempFile("result", null, tempDir) val finalState = runSpark( false, mainClassName(YarnExternalShuffleDriver.getClass), appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath), - extraConf = Map( - "spark.shuffle.service.enabled" -> "true", - "spark.shuffle.service.port" -> shuffleServicePort.toString - ) + extraConf = extraSparkConf() ) checkResult(finalState, result) assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists()) } } +/** + * Integration test for the external shuffle service with auth on. + */ +@ExtendedYarnTest +class YarnShuffleAuthSuite extends YarnShuffleIntegrationSuite { + + override def newYarnConfig(): YarnConfiguration = { + val yarnConfig = super.newYarnConfig() + yarnConfig.set(NETWORK_AUTH_ENABLED.key, "true") + yarnConfig.set(NETWORK_ENCRYPTION_ENABLED.key, "true") + yarnConfig + } + + override protected def extraSparkConf(): Map[String, String] = { + super.extraSparkConf() ++ Map( + NETWORK_AUTH_ENABLED.key -> "true", + NETWORK_ENCRYPTION_ENABLED.key -> "true" + ) + } + +} + private object YarnExternalShuffleDriver extends Logging with Matchers { val WAIT_TIMEOUT_MILLIS = 10000