diff --git a/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java b/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java
new file mode 100644
index 0000000000000000000000000000000000000000..4d254a0c4c9fe52b5cc3808b76655b50bf910732
--- /dev/null
+++ b/launcher/src/main/java/org/apache/spark/launcher/FilteredObjectInputStream.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.launcher;
+
+import java.io.InputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectStreamClass;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * An object input stream that only allows classes used by the launcher protocol to be in the
+ * serialized stream. See SPARK-20922.
+ */
+class FilteredObjectInputStream extends ObjectInputStream {
+
+  private static final List<String> ALLOWED_PACKAGES = Arrays.asList(
+    "org.apache.spark.launcher.",
+    "java.lang.");
+
+  FilteredObjectInputStream(InputStream is) throws IOException {
+    super(is);
+  }
+
+  @Override
+  protected Class<?> resolveClass(ObjectStreamClass desc)
+      throws IOException, ClassNotFoundException {
+
+    boolean isValid = ALLOWED_PACKAGES.stream().anyMatch(p -> desc.getName().startsWith(p));
+    if (!isValid) {
+      throw new IllegalArgumentException(
+        String.format("Unexpected class in stream: %s", desc.getName()));
+    }
+    return super.resolveClass(desc);
+  }
+
+}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java
index eec264909bbb62fdbe4bd273c0cabf6f9c0099b4..b4a8719e260538753afe727d60819c2f8c3be665 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherConnection.java
@@ -20,7 +20,6 @@ package org.apache.spark.launcher;
 import java.io.Closeable;
 import java.io.EOFException;
 import java.io.IOException;
-import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.net.Socket;
 import java.util.logging.Level;
@@ -53,7 +52,7 @@ abstract class LauncherConnection implements Closeable, Runnable {
   @Override
   public void run() {
     try {
-      ObjectInputStream in = new ObjectInputStream(socket.getInputStream());
+      FilteredObjectInputStream in = new FilteredObjectInputStream(socket.getInputStream());
       while (!closed) {
         Message msg = (Message) in.readObject();
         handle(msg);
diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
index 12f1a0ce2d1b4b4127f0f2480d68c0c77aafcd61..03c2934e2692ed6ecb33f7a7d8717eb433db9da8 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
@@ -19,8 +19,11 @@ package org.apache.spark.launcher;
 
 import java.io.Closeable;
 import java.io.IOException;
+import java.io.ObjectInputStream;
 import java.net.InetAddress;
 import java.net.Socket;
+import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.Semaphore;
@@ -120,31 +123,7 @@ public class LauncherServerSuite extends BaseSuite {
       Socket s = new Socket(InetAddress.getLoopbackAddress(),
         LauncherServer.getServerInstance().getPort());
       client = new TestClient(s);
-
-      // Try a few times since the client-side socket may not reflect the server-side close
-      // immediately.
-      boolean helloSent = false;
-      int maxTries = 10;
-      for (int i = 0; i < maxTries; i++) {
-        try {
-          if (!helloSent) {
-            client.send(new Hello(handle.getSecret(), "1.4.0"));
-            helloSent = true;
-          } else {
-            client.send(new SetAppId("appId"));
-          }
-          fail("Expected exception caused by connection timeout.");
-        } catch (IllegalStateException | IOException e) {
-          // Expected.
-          break;
-        } catch (AssertionError e) {
-          if (i < maxTries - 1) {
-            Thread.sleep(100);
-          } else {
-            throw new AssertionError("Test failed after " + maxTries + " attempts.", e);
-          }
-        }
-      }
+      waitForError(client, handle.getSecret());
     } finally {
       SparkLauncher.launcherConfig.remove(SparkLauncher.CHILD_CONNECTION_TIMEOUT);
       kill(handle);
@@ -183,6 +162,25 @@ public class LauncherServerSuite extends BaseSuite {
     }
   }
 
+  @Test
+  public void testStreamFiltering() throws Exception {
+    ChildProcAppHandle handle = LauncherServer.newAppHandle();
+    TestClient client = null;
+    try {
+      Socket s = new Socket(InetAddress.getLoopbackAddress(),
+        LauncherServer.getServerInstance().getPort());
+
+      client = new TestClient(s);
+      client.send(new EvilPayload());
+      waitForError(client, handle.getSecret());
+      assertEquals(0, EvilPayload.EVIL_BIT);
+    } finally {
+      kill(handle);
+      close(client);
+      client.clientThread.join();
+    }
+  }
+
   private void kill(SparkAppHandle handle) {
     if (handle != null) {
       handle.kill();
@@ -199,6 +197,35 @@ public class LauncherServerSuite extends BaseSuite {
     }
   }
 
+  /**
+   * Try a few times to get a client-side error, since the client-side socket may not reflect the
+   * server-side close immediately.
+   */
+  private void waitForError(TestClient client, String secret) throws Exception {
+    boolean helloSent = false;
+    int maxTries = 10;
+    for (int i = 0; i < maxTries; i++) {
+      try {
+        if (!helloSent) {
+          client.send(new Hello(secret, "1.4.0"));
+          helloSent = true;
+        } else {
+          client.send(new SetAppId("appId"));
+        }
+        fail("Expected error but message went through.");
+      } catch (IllegalStateException | IOException e) {
+        // Expected.
+        break;
+      } catch (AssertionError e) {
+        if (i < maxTries - 1) {
+          Thread.sleep(100);
+        } else {
+          throw new AssertionError("Test failed after " + maxTries + " attempts.", e);
+        }
+      }
+    }
+  }
+
   private static class TestClient extends LauncherConnection {
 
     final BlockingQueue<Message> inbound;
@@ -220,4 +247,19 @@ public class LauncherServerSuite extends BaseSuite {
 
   }
 
+  private static class EvilPayload extends LauncherProtocol.Message {
+
+    static int EVIL_BIT = 0;
+
+    // This field should cause the launcher server to throw an error and not deserialize the
+    // message.
+    private List<String> notAllowedField = Arrays.asList("disallowed");
+
+    private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
+      stream.defaultReadObject();
+      EVIL_BIT = 1;
+    }
+
+  }
+
 }