Skip to content
Snippets Groups Projects
Commit 300807c6 authored by Marcelo Vanzin
Browse files

[SPARK-21494][NETWORK] Use correct app id when authenticating to external service.

There was some code based on the old SASL handler in the new auth client that
was incorrectly using the SASL user as the user to authenticate against the
external shuffle service. This caused the external service to not be able to
find the correct secret to authenticate the connection, failing the connection.

In the course of debugging, I found that some log messages from the YARN shuffle
service were a little noisy, so I silenced some of them, and also added a couple
of new ones that helped find this issue. On top of that, I found that a check
in the code that records app secrets was wrong, causing more log spam and also
using an O(n) operation instead of an O(1) call.

Also added a new integration suite for the YARN shuffle service with auth on,
and verified that it failed before this change and passes now.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #18706 from vanzin/SPARK-21494.
parent ebc24a9b
No related branches found
No related tags found
No related merge requests found
......@@ -50,7 +50,6 @@ public class AuthClientBootstrap implements TransportClientBootstrap {
private final TransportConf conf;
private final String appId;
private final String authUser;
private final SecretKeyHolder secretKeyHolder;
public AuthClientBootstrap(
......@@ -65,7 +64,6 @@ public class AuthClientBootstrap implements TransportClientBootstrap {
// required by the protocol. At some point, though, it would be better for the actual app ID
// to be provided here.
this.appId = appId;
this.authUser = secretKeyHolder.getSaslUser(appId);
this.secretKeyHolder = secretKeyHolder;
}
......@@ -97,8 +95,8 @@ public class AuthClientBootstrap implements TransportClientBootstrap {
private void doSparkAuth(TransportClient client, Channel channel)
throws GeneralSecurityException, IOException {
String secretKey = secretKeyHolder.getSecretKey(authUser);
try (AuthEngine engine = new AuthEngine(authUser, secretKey, conf)) {
String secretKey = secretKeyHolder.getSecretKey(appId);
try (AuthEngine engine = new AuthEngine(appId, secretKey, conf)) {
ClientChallenge challenge = engine.challenge();
ByteBuf challengeData = Unpooled.buffer(challenge.encodedLength());
challenge.encode(challengeData);
......
......@@ -20,6 +20,7 @@ package org.apache.spark.network.crypto;
import java.nio.ByteBuffer;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
......@@ -113,7 +114,11 @@ class AuthRpcHandler extends RpcHandler {
// Here we have the client challenge, so perform the new auth protocol and set up the channel.
AuthEngine engine = null;
try {
engine = new AuthEngine(challenge.appId, secretKeyHolder.getSecretKey(challenge.appId), conf);
String secret = secretKeyHolder.getSecretKey(challenge.appId);
Preconditions.checkState(secret != null,
"Trying to authenticate non-registered app %s.", challenge.appId);
LOG.debug("Authenticating challenge for app {}.", challenge.appId);
engine = new AuthEngine(challenge.appId, secret, conf);
ServerResponse response = engine.respond(challenge);
ByteBuf responseData = Unpooled.buffer(response.encodedLength());
response.encode(responseData);
......
......@@ -47,7 +47,7 @@ public class ShuffleSecretManager implements SecretKeyHolder {
* fetching shuffle files written by other executors in this application.
*/
public void registerApp(String appId, String shuffleSecret) {
if (!shuffleSecretMap.contains(appId)) {
if (!shuffleSecretMap.containsKey(appId)) {
shuffleSecretMap.put(appId, shuffleSecret);
logger.info("Registered shuffle secret for application {}", appId);
} else {
......@@ -67,7 +67,7 @@ public class ShuffleSecretManager implements SecretKeyHolder {
* This is called when the application terminates.
*/
public void unregisterApp(String appId) {
if (shuffleSecretMap.contains(appId)) {
if (shuffleSecretMap.containsKey(appId)) {
shuffleSecretMap.remove(appId);
logger.info("Unregistered shuffle secret for application {}", appId);
} else {
......
......@@ -243,7 +243,6 @@ public class YarnShuffleService extends AuxiliaryService {
String appId = context.getApplicationId().toString();
try {
ByteBuffer shuffleSecret = context.getApplicationDataForService();
logger.info("Initializing application {}", appId);
if (isAuthenticationEnabled()) {
AppId fullId = new AppId(appId);
if (db != null) {
......@@ -262,7 +261,6 @@ public class YarnShuffleService extends AuxiliaryService {
public void stopApplication(ApplicationTerminationContext context) {
String appId = context.getApplicationId().toString();
try {
logger.info("Stopping application {}", appId);
if (isAuthenticationEnabled()) {
AppId fullId = new AppId(appId);
if (db != null) {
......
......@@ -26,7 +26,9 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.scalatest.Matchers
import org.apache.spark._
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.shuffle.ShuffleTestAccessor
import org.apache.spark.network.yarn.{YarnShuffleService, YarnTestAccessor}
import org.apache.spark.tags.ExtendedYarnTest
......@@ -46,28 +48,58 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite {
yarnConfig
}
/**
 * Builds the Spark configuration overrides that the tests in this suite pass to `runSpark`.
 *
 * Enables the external shuffle service, points it at the port reported by
 * `YarnTestAccessor.getShuffleServicePort`, and sets `MAX_EXECUTOR_FAILURES` to "1".
 * Declared `protected` so that subclasses can add their own settings on top.
 *
 * @return configuration key/value pairs for the Spark application under test
 */
protected def extraSparkConf(): Map[String, String] = {
  val shuffleServicePort = YarnTestAccessor.getShuffleServicePort
  // Logged so the port in use can be found in the test output.
  logInfo(s"Shuffle service port = $shuffleServicePort")
  Map(
    "spark.shuffle.service.enabled" -> "true",
    "spark.shuffle.service.port" -> shuffleServicePort.toString,
    MAX_EXECUTOR_FAILURES.key -> "1"
  )
}
test("external shuffle service") {
val shuffleServicePort = YarnTestAccessor.getShuffleServicePort
val shuffleService = YarnTestAccessor.getShuffleServiceInstance
val registeredExecFile = YarnTestAccessor.getRegisteredExecutorFile(shuffleService)
logInfo("Shuffle service port = " + shuffleServicePort)
val result = File.createTempFile("result", null, tempDir)
val finalState = runSpark(
false,
mainClassName(YarnExternalShuffleDriver.getClass),
appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath),
extraConf = Map(
"spark.shuffle.service.enabled" -> "true",
"spark.shuffle.service.port" -> shuffleServicePort.toString
)
extraConf = extraSparkConf()
)
checkResult(finalState, result)
assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists())
}
}
/**
 * Runs the external shuffle service integration tests inherited from
 * [[YarnShuffleIntegrationSuite]] with network authentication and encryption switched on,
 * both in the YARN configuration and in the Spark application configuration.
 */
@ExtendedYarnTest
class YarnShuffleAuthSuite extends YarnShuffleIntegrationSuite {

  /** The auth-related settings applied to both the YARN config and the Spark config. */
  private def authSettings: Map[String, String] = Map(
    NETWORK_AUTH_ENABLED.key -> "true",
    NETWORK_ENCRYPTION_ENABLED.key -> "true"
  )

  /** Starts from the parent's YARN configuration and turns on auth and encryption. */
  override def newYarnConfig(): YarnConfiguration = {
    val conf = super.newYarnConfig()
    authSettings.foreach { case (key, value) => conf.set(key, value) }
    conf
  }

  /** Adds the auth and encryption settings to the parent's Spark configuration. */
  override protected def extraSparkConf(): Map[String, String] =
    super.extraSparkConf() ++ authSettings
}
private object YarnExternalShuffleDriver extends Logging with Matchers {
val WAIT_TIMEOUT_MILLIS = 10000
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment