From 8efc6e986554ae66eab93cd64a9035d716adbab0 Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin <vanzin@cloudera.com>
Date: Thu, 1 Jun 2017 14:44:34 -0700
Subject: [PATCH] [SPARK-20922][CORE] Add whitelist of classes that can be
 deserialized by the launcher.

Blindly deserializing classes using Java serialization opens the code up to
issues in other libraries, since just deserializing data from a stream may
end up execution code (think readObject()).

Since the launcher protocol is pretty self-contained, there's just a handful
of classes it legitimately needs to deserialize, and they're in just two
packages, so add a filter that throws errors if classes from any other
package show up in the stream.

This also maintains backwards compatibility (the updated launcher code can
still communicate with the backend code in older Spark releases).

Tested with new and existing unit tests.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #18166 from vanzin/SPARK-20922.
---
 .../launcher/FilteredObjectInputStream.java   | 53 +++++++++++
 .../spark/launcher/LauncherConnection.java    |  3 +-
 .../spark/launcher/LauncherServerSuite.java   | 92 ++++++++++++++-----
 3 files changed, 121 insertions(+), 27 deletions(-)
 create mode 100644 launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java

diff --git a/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java b/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java
new file mode 100644
index 0000000000..4d254a0c4c
--- /dev/null
+++ b/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher;
+
+import java.io.InputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectStreamClass;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * An object input stream that only allows classes used by the launcher protocol to be in the
+ * serialized stream. See SPARK-20922.
+ */
+class FilteredObjectInputStream extends ObjectInputStream {
+
+  private static final List<String> ALLOWED_PACKAGES = Arrays.asList(
+    "org.apache.spark.launcher.",
+    "java.lang.");
+
+  FilteredObjectInputStream(InputStream is) throws IOException {
+    super(is);
+  }
+
+  @Override
+  protected Class<?> resolveClass(ObjectStreamClass desc)
+      throws IOException, ClassNotFoundException {
+
+    boolean isValid = ALLOWED_PACKAGES.stream().anyMatch(p -> desc.getName().startsWith(p));
+    if (!isValid) {
+      throw new IllegalArgumentException(
+        String.format("Unexpected class in stream: %s", desc.getName()));
+    }
+    return super.resolveClass(desc);
+  }
+
+}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java
index eec264909b..b4a8719e26 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java
@@ -20,7 +20,6 @@ package org.apache.spark.launcher;
 import java.io.Closeable;
 import java.io.EOFException;
 import java.io.IOException;
-import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.net.Socket;
 import java.util.logging.Level;
@@ -53,7 +52,7 @@ abstract class LauncherConnection implements Closeable, Runnable {
   @Override
   public void run() {
     try {
-      ObjectInputStream in = new ObjectInputStream(socket.getInputStream());
+      FilteredObjectInputStream in = new FilteredObjectInputStream(socket.getInputStream());
       while (!closed) {
         Message msg = (Message) in.readObject();
         handle(msg);
diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
index 12f1a0ce2d..03c2934e26 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
@@ -19,8 +19,11 @@ package org.apache.spark.launcher;
 
 import java.io.Closeable;
 import java.io.IOException;
+import java.io.ObjectInputStream;
 import java.net.InetAddress;
 import java.net.Socket;
+import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.Semaphore;
@@ -120,31 +123,7 @@ public class LauncherServerSuite extends BaseSuite {
       Socket s = new Socket(InetAddress.getLoopbackAddress(),
         LauncherServer.getServerInstance().getPort());
       client = new TestClient(s);
-
-      // Try a few times since the client-side socket may not reflect the server-side close
-      // immediately.
-      boolean helloSent = false;
-      int maxTries = 10;
-      for (int i = 0; i < maxTries; i++) {
-        try {
-          if (!helloSent) {
-            client.send(new Hello(handle.getSecret(), "1.4.0"));
-            helloSent = true;
-          } else {
-            client.send(new SetAppId("appId"));
-          }
-          fail("Expected exception caused by connection timeout.");
-        } catch (IllegalStateException | IOException e) {
-          // Expected.
-          break;
-        } catch (AssertionError e) {
-          if (i < maxTries - 1) {
-            Thread.sleep(100);
-          } else {
-            throw new AssertionError("Test failed after " + maxTries + " attempts.", e);
-          }
-        }
-      }
+      waitForError(client, handle.getSecret());
     } finally {
       SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT);
       kill(handle);
@@ -183,6 +162,25 @@ public class LauncherServerSuite extends BaseSuite {
     }
   }
 
+  @Test
+  public void testStreamFiltering() throws Exception {
+    ChildProcAppHandle handle = LauncherServer.newAppHandle();
+    TestClient client = null;
+    try {
+      Socket s = new Socket(InetAddress.getLoopbackAddress(),
+        LauncherServer.getServerInstance().getPort());
+
+      client = new TestClient(s);
+      client.send(new EvilPayload());
+      waitForError(client, handle.getSecret());
+      assertEquals(0, EvilPayload.EVIL_BIT);
+    } finally {
+      kill(handle);
+      close(client);
+      client.clientThread.join();
+    }
+  }
+
   private void kill(SparkAppHandle handle) {
     if (handle != null) {
       handle.kill();
@@ -199,6 +197,35 @@ public class LauncherServerSuite extends BaseSuite {
     }
   }
 
+  /**
+   * Try a few times to get a client-side error, since the client-side socket may not reflect the
+   * server-side close immediately.
+   */
+  private void waitForError(TestClient client, String secret) throws Exception {
+    boolean helloSent = false;
+    int maxTries = 10;
+    for (int i = 0; i < maxTries; i++) {
+      try {
+        if (!helloSent) {
+          client.send(new Hello(secret, "1.4.0"));
+          helloSent = true;
+        } else {
+          client.send(new SetAppId("appId"));
+        }
+        fail("Expected error but message went through.");
+      } catch (IllegalStateException | IOException e) {
+        // Expected.
+        break;
+      } catch (AssertionError e) {
+        if (i < maxTries - 1) {
+          Thread.sleep(100);
+        } else {
+          throw new AssertionError("Test failed after " + maxTries + " attempts.", e);
+        }
+      }
+    }
+  }
+
   private static class TestClient extends LauncherConnection {
 
     final BlockingQueue<Message> inbound;
@@ -220,4 +247,19 @@ public class LauncherServerSuite extends BaseSuite {
 
   }
 
+  private static class EvilPayload extends LauncherProtocol.Message {
+
+    static int EVIL_BIT = 0;
+
+    // This field should cause the launcher server to throw an error and not deserialize the
+    // message.
+    private List<String> notAllowedField = Arrays.asList("disallowed");
+
+    private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
+      stream.defaultReadObject();
+      EVIL_BIT = 1;
+    }
+
+  }
+
 }
-- 
GitLab