Skip to content
Snippets Groups Projects
Commit 8efc6e98 authored by Marcelo Vanzin's avatar Marcelo Vanzin
Browse files

[SPARK-20922][CORE] Add whitelist of classes that can be deserialized by the launcher.

Blindly deserializing classes using Java serialization opens the code up to
issues in other libraries, since just deserializing data from a stream may
end up execution code (think readObject()).

Since the launcher protocol is pretty self-contained, there's just a handful
of classes it legitimately needs to deserialize, and they're in just two
packages, so add a filter that throws errors if classes from any other
package show up in the stream.

This also maintains backwards compatibility (the updated launcher code can
still communicate with the backend code in older Spark releases).

Tested with new and existing unit tests.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #18166 from vanzin/SPARK-20922.
parent 640afa49
No related branches found
No related tags found
No related merge requests found
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.launcher;
import java.io.InputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import java.util.Arrays;
import java.util.List;
/**
* An object input stream that only allows classes used by the launcher protocol to be in the
* serialized stream. See SPARK-20922.
*/
class FilteredObjectInputStream extends ObjectInputStream {
private static final List<String> ALLOWED_PACKAGES = Arrays.asList(
"org.apache.spark.launcher.",
"java.lang.");
FilteredObjectInputStream(InputStream is) throws IOException {
super(is);
}
@Override
protected Class<?> resolveClass(ObjectStreamClass desc)
throws IOException, ClassNotFoundException {
boolean isValid = ALLOWED_PACKAGES.stream().anyMatch(p -> desc.getName().startsWith(p));
if (!isValid) {
throw new IllegalArgumentException(
String.format("Unexpected class in stream: %s", desc.getName()));
}
return super.resolveClass(desc);
}
}
...@@ -20,7 +20,6 @@ package org.apache.spark.launcher; ...@@ -20,7 +20,6 @@ package org.apache.spark.launcher;
import java.io.Closeable; import java.io.Closeable;
import java.io.EOFException; import java.io.EOFException;
import java.io.IOException; import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream; import java.io.ObjectOutputStream;
import java.net.Socket; import java.net.Socket;
import java.util.logging.Level; import java.util.logging.Level;
...@@ -53,7 +52,7 @@ abstract class LauncherConnection implements Closeable, Runnable { ...@@ -53,7 +52,7 @@ abstract class LauncherConnection implements Closeable, Runnable {
@Override @Override
public void run() { public void run() {
try { try {
ObjectInputStream in = new ObjectInputStream(socket.getInputStream()); FilteredObjectInputStream in = new FilteredObjectInputStream(socket.getInputStream());
while (!closed) { while (!closed) {
Message msg = (Message) in.readObject(); Message msg = (Message) in.readObject();
handle(msg); handle(msg);
......
...@@ -19,8 +19,11 @@ package org.apache.spark.launcher; ...@@ -19,8 +19,11 @@ package org.apache.spark.launcher;
import java.io.Closeable; import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.io.ObjectInputStream;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.Socket; import java.net.Socket;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore; import java.util.concurrent.Semaphore;
...@@ -120,31 +123,7 @@ public class LauncherServerSuite extends BaseSuite { ...@@ -120,31 +123,7 @@ public class LauncherServerSuite extends BaseSuite {
Socket s = new Socket(InetAddress.getLoopbackAddress(), Socket s = new Socket(InetAddress.getLoopbackAddress(),
LauncherServer.getServerInstance().getPort()); LauncherServer.getServerInstance().getPort());
client = new TestClient(s); client = new TestClient(s);
waitForError(client, handle.getSecret());
// Try a few times since the client-side socket may not reflect the server-side close
// immediately.
boolean helloSent = false;
int maxTries = 10;
for (int i = 0; i < maxTries; i++) {
try {
if (!helloSent) {
client.send(new Hello(handle.getSecret(), "1.4.0"));
helloSent = true;
} else {
client.send(new SetAppId("appId"));
}
fail("Expected exception caused by connection timeout.");
} catch (IllegalStateException | IOException e) {
// Expected.
break;
} catch (AssertionError e) {
if (i < maxTries - 1) {
Thread.sleep(100);
} else {
throw new AssertionError("Test failed after " + maxTries + " attempts.", e);
}
}
}
} finally { } finally {
SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT); SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT);
kill(handle); kill(handle);
...@@ -183,6 +162,25 @@ public class LauncherServerSuite extends BaseSuite { ...@@ -183,6 +162,25 @@ public class LauncherServerSuite extends BaseSuite {
} }
} }
@Test
public void testStreamFiltering() throws Exception {
ChildProcAppHandle handle = LauncherServer.newAppHandle();
TestClient client = null;
try {
Socket s = new Socket(InetAddress.getLoopbackAddress(),
LauncherServer.getServerInstance().getPort());
client = new TestClient(s);
client.send(new EvilPayload());
waitForError(client, handle.getSecret());
assertEquals(0, EvilPayload.EVIL_BIT);
} finally {
kill(handle);
close(client);
client.clientThread.join();
}
}
private void kill(SparkAppHandle handle) { private void kill(SparkAppHandle handle) {
if (handle != null) { if (handle != null) {
handle.kill(); handle.kill();
...@@ -199,6 +197,35 @@ public class LauncherServerSuite extends BaseSuite { ...@@ -199,6 +197,35 @@ public class LauncherServerSuite extends BaseSuite {
} }
} }
/**
* Try a few times to get a client-side error, since the client-side socket may not reflect the
* server-side close immediately.
*/
private void waitForError(TestClient client, String secret) throws Exception {
boolean helloSent = false;
int maxTries = 10;
for (int i = 0; i < maxTries; i++) {
try {
if (!helloSent) {
client.send(new Hello(secret, "1.4.0"));
helloSent = true;
} else {
client.send(new SetAppId("appId"));
}
fail("Expected error but message went through.");
} catch (IllegalStateException | IOException e) {
// Expected.
break;
} catch (AssertionError e) {
if (i < maxTries - 1) {
Thread.sleep(100);
} else {
throw new AssertionError("Test failed after " + maxTries + " attempts.", e);
}
}
}
}
private static class TestClient extends LauncherConnection { private static class TestClient extends LauncherConnection {
final BlockingQueue<Message> inbound; final BlockingQueue<Message> inbound;
...@@ -220,4 +247,19 @@ public class LauncherServerSuite extends BaseSuite { ...@@ -220,4 +247,19 @@ public class LauncherServerSuite extends BaseSuite {
} }
private static class EvilPayload extends LauncherProtocol.Message {
static int EVIL_BIT = 0;
// This field should cause the launcher server to throw an error and not deserialize the
// message.
private List<String> notAllowedField = Arrays.asList("disallowed");
private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
stream.defaultReadObject();
EVIL_BIT = 1;
}
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment