From e4065376d2b4eec178a119476fa95b26f440c076 Mon Sep 17 00:00:00 2001
From: Adam Budde <budde@amazon.com>
Date: Wed, 22 Feb 2017 11:32:36 -0500
Subject: [PATCH] [SPARK-19405][STREAMING] Support for cross-account Kinesis
 reads via STS
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Add dependency on aws-java-sdk-sts
- Replace SerializableAWSCredentials with a new SerializableCredentialsProvider interface (see the composition sketch after this list)
- Make KinesisReceiver take SerializableCredentialsProvider as argument and
  pass credential provider to KCL
- Add new implementations of KinesisUtils.createStream() that take STS
  arguments
- Make JavaKinesisStreamSuite test the entire KinesisUtils Java API
- Update KCL/AWS SDK dependencies to 1.7.x/1.11.x
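
As a rough sketch of how the new credential provider classes compose internally (they are `private[kinesis]`, so applications go through the new `KinesisUtils.createStream()` overloads instead; the keypair, role ARN, session name, and external ID below are placeholders):

```scala
// Illustrative only -- this wiring lives inside the org.apache.spark.streaming.kinesis package.
package org.apache.spark.streaming.kinesis

object CredentialsCompositionSketch {
  // Long-lived credentials used to call sts:AssumeRole. BasicCredentialsProvider
  // falls back to DefaultAWSCredentialsProviderChain if the keypair is unusable.
  val longLivedCreds = BasicCredentialsProvider(
    awsAccessKeyId = "AKIAEXAMPLE",
    awsSecretKey = "exampleSecretKey")

  // Serializable wrapper shipped to executors; calling .provider builds an
  // STSAssumeRoleSessionCredentialsProvider that assumes the cross-account role.
  val stsCreds = STSCredentialsProvider(
    stsRoleArn = "arn:aws:iam::123456789012:role/CrossAccountKinesisRead",
    stsSessionName = "spark-kinesis-session",
    stsExternalId = Some("my-external-id"),
    longLivedCredsProvider = longLivedCreds)

  // This is the AWSCredentialsProvider handed to the KCL Worker configuration.
  def kclCredentialsProvider = stsCreds.provider
}
```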

## What changes were proposed in this pull request?

[JIRA link with detailed description.](https://issues.apache.org/jira/browse/SPARK-19405)

* Replace SerializableAWSCredentials with a new SerializableCredentialsProvider interface whose implementations take the five optional AWS auth config params (keypair plus STS role ARN, session name, and external ID) and return the appropriate credential provider object
* Add new public createStream() APIs for specifying these parameters in KinesisUtils (usage sketch below)
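
A minimal usage sketch of the new Scala `createStream()` overload (assuming an existing `StreamingContext` named `ssc`; the app name, stream name, endpoint, and role ARN are placeholders; a null keypair means the long-lived credentials come from the DefaultAWSCredentialsProviderChain):

```scala
import java.nio.charset.StandardCharsets

import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record

import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.kinesis.KinesisUtils

// Decode each Kinesis record payload to a String.
val handler = (record: Record) => StandardCharsets.UTF_8.decode(record.getData).toString

val stream = KinesisUtils.createStream(
  ssc,                                           // existing StreamingContext
  "myKinesisApp",                                // KCL app name (DynamoDB lease table)
  "myCrossAccountStream",                        // stream owned by the other account
  "https://kinesis.us-east-1.amazonaws.com",     // Kinesis endpoint URL
  "us-east-1",                                   // region for KCL DynamoDB/CloudWatch state
  InitialPositionInStream.LATEST,
  Seconds(10),                                   // checkpoint interval
  StorageLevel.MEMORY_AND_DISK_2,
  handler,
  null, null,                                    // awsAccessKeyId/awsSecretKey: use default chain
  "arn:aws:iam::123456789012:role/CrossAccountKinesisRead",  // stsAssumeRoleArn
  "spark-kinesis-session",                       // stsSessionName
  null)                                          // stsExternalId (optional)
```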

## How was this patch tested?

* Manually tested using an explicit keypair and an instance profile to read data from a Kinesis stream in a separate account (it is difficult to write an automated test that orchestrates creating and assuming IAM roles across separate accounts)
* Expanded JavaKinesisStreamSuite to test the entire Java API in KinesisUtils

## License acknowledgement
This contribution is my original work and I license the work to the project under the project’s open source license.

Author: Budde <budde@amazon.com>

Closes #16744 from budde/master.
---
 external/kinesis-asl/pom.xml                  |   5 +
 .../streaming/JavaKinesisWordCountASL.java    |   2 +-
 .../streaming/KinesisExampleUtils.scala       |  35 ++++
 .../streaming/KinesisWordCountASL.scala       |   2 +-
 .../kinesis/KinesisBackedBlockRDD.scala       |   8 +-
 .../kinesis/KinesisCheckpointer.scala         |   2 +-
 .../kinesis/KinesisInputDStream.scala         |   7 +-
 .../streaming/kinesis/KinesisReceiver.scala   |  51 ++---
 .../kinesis/KinesisRecordProcessor.scala      |   2 +-
 .../streaming/kinesis/KinesisTestUtils.scala  |  14 +-
 .../streaming/kinesis/KinesisUtils.scala      | 192 ++++++++++++++++--
 .../SerializableCredentialsProvider.scala     |  85 ++++++++
 .../kinesis/JavaKinesisStreamSuite.java       |  35 +++-
 .../kinesis/KinesisReceiverSuite.scala        |  25 ++-
 .../kinesis/KinesisStreamSuite.scala          |   9 +-
 pom.xml                                       |   4 +-
 python/pyspark/streaming/kinesis.py           |  12 +-
 17 files changed, 407 insertions(+), 83 deletions(-)
 create mode 100644 external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala
 create mode 100644 external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala

diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml
index b2bac7c938..daa79e7916 100644
--- a/external/kinesis-asl/pom.xml
+++ b/external/kinesis-asl/pom.xml
@@ -58,6 +58,11 @@
       <artifactId>amazon-kinesis-client</artifactId>
       <version>${aws.kinesis.client.version}</version>
     </dependency>
+    <dependency>
+      <groupId>com.amazonaws</groupId>
+      <artifactId>aws-java-sdk-sts</artifactId>
+      <version>${aws.java.sdk.version}</version>
+    </dependency>
     <dependency>
       <groupId>com.amazonaws</groupId>
       <artifactId>amazon-kinesis-producer</artifactId>
diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java
index d40bd3ff56..d1274a687f 100644
--- a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java
+++ b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java
@@ -127,7 +127,7 @@ public final class JavaKinesisWordCountASL { // needs to be public for access fr
 
     // Get the region name from the endpoint URL to save Kinesis Client Library metadata in
     // DynamoDB of the same region as the Kinesis stream
-    String regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName();
+    String regionName = KinesisExampleUtils.getRegionNameByEndpoint(endpointUrl);
 
     // Setup the Spark config and StreamingContext
     SparkConf sparkConfig = new SparkConf().setAppName("JavaKinesisWordCountASL");
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala
new file mode 100644
index 0000000000..2eebd6130d
--- /dev/null
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.streaming
+
+import scala.collection.JavaConverters._
+
+import com.amazonaws.regions.RegionUtils
+import com.amazonaws.services.kinesis.AmazonKinesis
+
+private[streaming] object KinesisExampleUtils {
+  def getRegionNameByEndpoint(endpoint: String): String = {
+    val uri = new java.net.URI(endpoint)
+    RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX)
+      .asScala
+      .find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost))
+      .map(_.getName)
+      .getOrElse(
+        throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint"))
+  }
+}
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
index a70c13d7d6..f14117b708 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala
@@ -127,7 +127,7 @@ object KinesisWordCountASL extends Logging {
 
     // Get the region name from the endpoint URL to save Kinesis Client Library metadata in
     // DynamoDB of the same region as the Kinesis stream
-    val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()
+    val regionName = KinesisExampleUtils.getRegionNameByEndpoint(endpointUrl)
 
     // Setup the SparkConfig and StreamingContext
     val sparkConfig = new SparkConf().setAppName("KinesisWordCountASL")
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
index 45dc3c388c..23c4d99e50 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
@@ -79,7 +79,7 @@ class KinesisBackedBlockRDD[T: ClassTag](
     @transient private val isBlockIdValid: Array[Boolean] = Array.empty,
     val retryTimeoutMs: Int = 10000,
     val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _,
-    val awsCredentialsOption: Option[SerializableAWSCredentials] = None
+    val kinesisCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider
   ) extends BlockRDD[T](sc, _blockIds) {
 
   require(_blockIds.length == arrayOfseqNumberRanges.length,
@@ -105,9 +105,7 @@ class KinesisBackedBlockRDD[T: ClassTag](
     }
 
     def getBlockFromKinesis(): Iterator[T] = {
-      val credentials = awsCredentialsOption.getOrElse {
-        new DefaultAWSCredentialsProviderChain().getCredentials()
-      }
+      val credentials = kinesisCredsProvider.provider.getCredentials
       partition.seqNumberRanges.ranges.iterator.flatMap { range =>
         new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName,
           range, retryTimeoutMs).map(messageHandler)
@@ -143,7 +141,7 @@ class KinesisSequenceRangeIterator(
   private var lastSeqNumber: String = null
   private var internalIterator: Iterator[Record] = null
 
-  client.setEndpoint(endpointUrl, "kinesis", regionId)
+  client.setEndpoint(endpointUrl)
 
   override protected def getNext(): Record = {
     var nextRecord: Record = null
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
index c445c15a5f..5fb83b26f8 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
@@ -21,7 +21,7 @@ import java.util.concurrent._
 import scala.util.control.NonFatal
 
 import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
-import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
+import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.streaming.Duration
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
index 5223c81a8e..fbc6b99443 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
@@ -39,7 +39,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
     checkpointInterval: Duration,
     storageLevel: StorageLevel,
     messageHandler: Record => T,
-    awsCredentialsOption: Option[SerializableAWSCredentials]
+    kinesisCredsProvider: SerializableCredentialsProvider
   ) extends ReceiverInputDStream[T](_ssc) {
 
   private[streaming]
@@ -61,7 +61,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
         isBlockIdValid = isBlockIdValid,
         retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt,
         messageHandler = messageHandler,
-        awsCredentialsOption = awsCredentialsOption)
+        kinesisCredsProvider = kinesisCredsProvider)
     } else {
       logWarning("Kinesis sequence number information was not present with some block metadata," +
         " it may not be possible to recover from failures")
@@ -71,6 +71,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
 
   override def getReceiver(): Receiver[T] = {
     new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream,
-      checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption)
+      checkpointAppName, checkpointInterval, storageLevel, messageHandler,
+      kinesisCredsProvider)
   }
 }
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index 393e56a393..13fc54e531 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -23,7 +23,6 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.control.NonFatal
 
-import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain}
 import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory}
 import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker}
 import com.amazonaws.services.kinesis.model.Record
@@ -34,13 +33,6 @@ import org.apache.spark.streaming.Duration
 import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver}
 import org.apache.spark.util.Utils
 
-private[kinesis]
-case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
-  extends AWSCredentials {
-  override def getAWSAccessKeyId: String = accessKeyId
-  override def getAWSSecretKey: String = secretKey
-}
-
 /**
  * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver.
  * This implementation relies on the Kinesis Client Library (KCL) Worker as described here:
@@ -78,8 +70,9 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
  *                            See the Kinesis Spark Streaming documentation for more
  *                            details on the different types of checkpoints.
  * @param storageLevel Storage level to use for storing the received objects
- * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies
- *                             the credentials
+ * @param kinesisCredsProvider SerializableCredentialsProvider instance that will be used to
+ *                             generate the AWSCredentialsProvider instance used for KCL
+ *                             authorization.
  */
 private[kinesis] class KinesisReceiver[T](
     val streamName: String,
@@ -90,7 +83,7 @@ private[kinesis] class KinesisReceiver[T](
     checkpointInterval: Duration,
     storageLevel: StorageLevel,
     messageHandler: Record => T,
-    awsCredentialsOption: Option[SerializableAWSCredentials])
+    kinesisCredsProvider: SerializableCredentialsProvider)
   extends Receiver[T](storageLevel) with Logging { receiver =>
 
   /*
@@ -147,14 +140,15 @@ private[kinesis] class KinesisReceiver[T](
     workerId = Utils.localHostName() + ":" + UUID.randomUUID()
 
     kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId)
-    // KCL config instance
-    val awsCredProvider = resolveAWSCredentialsProvider()
-    val kinesisClientLibConfiguration =
-      new KinesisClientLibConfiguration(checkpointAppName, streamName, awsCredProvider, workerId)
-      .withKinesisEndpoint(endpointUrl)
-      .withInitialPositionInStream(initialPositionInStream)
-      .withTaskBackoffTimeMillis(500)
-      .withRegionName(regionName)
+    val kinesisClientLibConfiguration = new KinesisClientLibConfiguration(
+          checkpointAppName,
+          streamName,
+          kinesisCredsProvider.provider,
+          workerId)
+        .withKinesisEndpoint(endpointUrl)
+        .withInitialPositionInStream(initialPositionInStream)
+        .withTaskBackoffTimeMillis(500)
+        .withRegionName(regionName)
 
    /*
     *  RecordProcessorFactory creates impls of IRecordProcessor.
@@ -305,25 +299,6 @@ private[kinesis] class KinesisReceiver[T](
     }
   }
 
-  /**
-   * If AWS credential is provided, return a AWSCredentialProvider returning that credential.
-   * Otherwise, return the DefaultAWSCredentialsProviderChain.
-   */
-  private def resolveAWSCredentialsProvider(): AWSCredentialsProvider = {
-    awsCredentialsOption match {
-      case Some(awsCredentials) =>
-        logInfo("Using provided AWS credentials")
-        new AWSCredentialsProvider {
-          override def getCredentials: AWSCredentials = awsCredentials
-          override def refresh(): Unit = { }
-        }
-      case None =>
-        logInfo("Using DefaultAWSCredentialsProviderChain")
-        new DefaultAWSCredentialsProviderChain()
-    }
-  }
-
-
   /**
    * Class to handle blocks generated by this receiver's block generator. Specifically, in
    * the context of the Kinesis Receiver, this handler does the following.
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index 73ccc4ad23..8c6a399dd7 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -23,7 +23,7 @@ import scala.util.control.NonFatal
 
 import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException}
 import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer}
-import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
+import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason
 import com.amazonaws.services.kinesis.model.Record
 
 import org.apache.spark.internal.Logging
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
index f183ef00b3..73ac7a3cd2 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
@@ -30,7 +30,7 @@ import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
 import com.amazonaws.regions.RegionUtils
 import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient
 import com.amazonaws.services.dynamodbv2.document.DynamoDB
-import com.amazonaws.services.kinesis.AmazonKinesisClient
+import com.amazonaws.services.kinesis.{AmazonKinesis, AmazonKinesisClient}
 import com.amazonaws.services.kinesis.model._
 
 import org.apache.spark.internal.Logging
@@ -43,7 +43,7 @@ import org.apache.spark.internal.Logging
 private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Logging {
 
   val endpointUrl = KinesisTestUtils.endpointUrl
-  val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()
+  val regionName = KinesisTestUtils.getRegionNameByEndpoint(endpointUrl)
 
   private val createStreamTimeoutSeconds = 300
   private val describeStreamPollTimeSeconds = 1
@@ -205,6 +205,16 @@ private[kinesis] object KinesisTestUtils {
   val endVarNameForEndpoint = "KINESIS_TEST_ENDPOINT_URL"
   val defaultEndpointUrl = "https://kinesis.us-west-2.amazonaws.com"
 
+  def getRegionNameByEndpoint(endpoint: String): String = {
+    val uri = new java.net.URI(endpoint)
+    RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX)
+      .asScala
+      .find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost))
+      .map(_.getName)
+      .getOrElse(
+        throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint"))
+  }
+
   lazy val shouldRunTests = {
     val isEnvSet = sys.env.get(envVarNameForEnablingTests) == Some("1")
     if (isEnvSet) {
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
index b2daffa34c..2d777982e7 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
@@ -73,7 +73,7 @@ object KinesisUtils {
     ssc.withNamedScope("kinesis stream") {
       new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
         initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
-        cleanedHandler, None)
+        cleanedHandler, DefaultCredentialsProvider)
     }
   }
 
@@ -123,9 +123,80 @@ object KinesisUtils {
     // scalastyle:on
     val cleanedHandler = ssc.sc.clean(messageHandler)
     ssc.withNamedScope("kinesis stream") {
+      val kinesisCredsProvider = BasicCredentialsProvider(
+        awsAccessKeyId = awsAccessKeyId,
+        awsSecretKey = awsSecretKey)
       new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
         initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
-        cleanedHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
+        cleanedHandler, kinesisCredsProvider)
+    }
+  }
+
+  /**
+   * Create an input stream that pulls messages from a Kinesis stream.
+   * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+   *
+   * @param ssc StreamingContext object
+   * @param kinesisAppName  Kinesis application name used by the Kinesis Client Library
+   *                        (KCL) to update DynamoDB
+   * @param streamName   Kinesis stream name
+   * @param endpointUrl  Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+   * @param regionName   Name of region used by the Kinesis Client Library (KCL) to update
+   *                     DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+   * @param initialPositionInStream  In the absence of Kinesis checkpoint info, this is the
+   *                                 worker's initial starting position in the stream.
+   *                                 The values are either the beginning of the stream
+   *                                 per Kinesis' limit of 24 hours
+   *                                 (InitialPositionInStream.TRIM_HORIZON) or
+   *                                 the tip of the stream (InitialPositionInStream.LATEST).
+   * @param checkpointInterval  Checkpoint interval for Kinesis checkpointing.
+   *                            See the Kinesis Spark Streaming documentation for more
+   *                            details on the different types of checkpoints.
+   * @param storageLevel Storage level to use for storing the received objects.
+   *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
+   * @param messageHandler A custom message handler that can generate a generic output from a
+   *                       Kinesis `Record`, which contains both message data, and metadata.
+   * @param awsAccessKeyId  AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
+   * @param awsSecretKey  AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
+   * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from
+   *                         Kinesis stream.
+   * @param stsSessionName Name to uniquely identify STS sessions if multiple principals assume
+   *                       the same role.
+   * @param stsExternalId External ID that can be used to validate against the assumed IAM role's
+   *                      trust policy.
+   *
+   * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing
+   * is enabled. Make sure that your checkpoint directory is secure.
+   */
+  // scalastyle:off
+  def createStream[T: ClassTag](
+      ssc: StreamingContext,
+      kinesisAppName: String,
+      streamName: String,
+      endpointUrl: String,
+      regionName: String,
+      initialPositionInStream: InitialPositionInStream,
+      checkpointInterval: Duration,
+      storageLevel: StorageLevel,
+      messageHandler: Record => T,
+      awsAccessKeyId: String,
+      awsSecretKey: String,
+      stsAssumeRoleArn: String,
+      stsSessionName: String,
+      stsExternalId: String): ReceiverInputDStream[T] = {
+    // scalastyle:on
+    val cleanedHandler = ssc.sc.clean(messageHandler)
+    ssc.withNamedScope("kinesis stream") {
+      val kinesisCredsProvider = STSCredentialsProvider(
+        stsRoleArn = stsAssumeRoleArn,
+        stsSessionName = stsSessionName,
+        stsExternalId = Option(stsExternalId),
+        longLivedCredsProvider = BasicCredentialsProvider(
+          awsAccessKeyId = awsAccessKeyId,
+          awsSecretKey = awsSecretKey))
+      new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
+        initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
+        cleanedHandler, kinesisCredsProvider)
     }
   }
 
@@ -169,7 +240,7 @@ object KinesisUtils {
     ssc.withNamedScope("kinesis stream") {
       new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
         initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
-        defaultMessageHandler, None)
+        defaultMessageHandler, DefaultCredentialsProvider)
     }
   }
 
@@ -213,9 +284,12 @@ object KinesisUtils {
       awsAccessKeyId: String,
       awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = {
     ssc.withNamedScope("kinesis stream") {
+      val kinesisCredsProvider = BasicCredentialsProvider(
+        awsAccessKeyId = awsAccessKeyId,
+        awsSecretKey = awsSecretKey)
       new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
         initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
-        defaultMessageHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
+        defaultMessageHandler, kinesisCredsProvider)
     }
   }
 
@@ -319,6 +393,68 @@ object KinesisUtils {
       awsAccessKeyId, awsSecretKey)
   }
 
+  /**
+   * Create an input stream that pulls messages from a Kinesis stream.
+   * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+   *
+   * @param jssc Java StreamingContext object
+   * @param kinesisAppName  Kinesis application name used by the Kinesis Client Library
+   *                        (KCL) to update DynamoDB
+   * @param streamName   Kinesis stream name
+   * @param endpointUrl  Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+   * @param regionName   Name of region used by the Kinesis Client Library (KCL) to update
+   *                     DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+   * @param initialPositionInStream  In the absence of Kinesis checkpoint info, this is the
+   *                                 worker's initial starting position in the stream.
+   *                                 The values are either the beginning of the stream
+   *                                 per Kinesis' limit of 24 hours
+   *                                 (InitialPositionInStream.TRIM_HORIZON) or
+   *                                 the tip of the stream (InitialPositionInStream.LATEST).
+   * @param checkpointInterval  Checkpoint interval for Kinesis checkpointing.
+   *                            See the Kinesis Spark Streaming documentation for more
+   *                            details on the different types of checkpoints.
+   * @param storageLevel Storage level to use for storing the received objects.
+   *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
+   * @param messageHandler A custom message handler that can generate a generic output from a
+   *                       Kinesis `Record`, which contains both message data, and metadata.
+   * @param recordClass Class of the records in DStream
+   * @param awsAccessKeyId  AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
+   * @param awsSecretKey  AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
+   * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from
+   *                         Kinesis stream.
+   * @param stsSessionName Name to uniquely identify STS sessions if multiple principals assume
+   *                       the same role.
+   * @param stsExternalId External ID that can be used to validate against the assumed IAM role's
+   *                      trust policy.
+   *
+   * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing
+   * is enabled. Make sure that your checkpoint directory is secure.
+   */
+  // scalastyle:off
+  def createStream[T](
+      jssc: JavaStreamingContext,
+      kinesisAppName: String,
+      streamName: String,
+      endpointUrl: String,
+      regionName: String,
+      initialPositionInStream: InitialPositionInStream,
+      checkpointInterval: Duration,
+      storageLevel: StorageLevel,
+      messageHandler: JFunction[Record, T],
+      recordClass: Class[T],
+      awsAccessKeyId: String,
+      awsSecretKey: String,
+      stsAssumeRoleArn: String,
+      stsSessionName: String,
+      stsExternalId: String): JavaReceiverInputDStream[T] = {
+    // scalastyle:on
+    implicit val recordCmt: ClassTag[T] = ClassTag(recordClass)
+    val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_))
+    createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+      initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler,
+      awsAccessKeyId, awsSecretKey, stsAssumeRoleArn, stsSessionName, stsExternalId)
+  }
+
   /**
    * Create an input stream that pulls messages from a Kinesis stream.
    * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
@@ -404,10 +540,6 @@ object KinesisUtils {
       defaultMessageHandler(_), awsAccessKeyId, awsSecretKey)
   }
 
-  private def getRegionByEndpoint(endpointUrl: String): String = {
-    RegionUtils.getRegionByEndpoint(endpointUrl).getName()
-  }
-
   private def validateRegion(regionName: String): String = {
     Option(RegionUtils.getRegion(regionName)).map { _.getName }.getOrElse {
       throw new IllegalArgumentException(s"Region name '$regionName' is not valid")
@@ -439,6 +571,7 @@ private class KinesisUtilsPythonHelper {
     }
   }
 
+  // scalastyle:off
   def createStream(
       jssc: JavaStreamingContext,
       kinesisAppName: String,
@@ -449,22 +582,43 @@ private class KinesisUtilsPythonHelper {
       checkpointInterval: Duration,
       storageLevel: StorageLevel,
       awsAccessKeyId: String,
-      awsSecretKey: String
-      ): JavaReceiverInputDStream[Array[Byte]] = {
+      awsSecretKey: String,
+      stsAssumeRoleArn: String,
+      stsSessionName: String,
+      stsExternalId: String): JavaReceiverInputDStream[Array[Byte]] = {
+    // scalastyle:on
+    if (!(stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null)
+        && !(stsAssumeRoleArn == null && stsSessionName == null && stsExternalId == null)) {
+      throw new IllegalArgumentException("stsAssumeRoleArn, stsSessionName, and stsExternalId " +
+        "must all be defined or all be null")
+    }
+
+    if (stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null) {
+      validateAwsCreds(awsAccessKeyId, awsSecretKey)
+      KinesisUtils.createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+        getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel,
+        KinesisUtils.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey,
+        stsAssumeRoleArn, stsSessionName, stsExternalId)
+    } else {
+      validateAwsCreds(awsAccessKeyId, awsSecretKey)
+      if (awsAccessKeyId == null && awsSecretKey == null) {
+        KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName,
+          getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel)
+      } else {
+        KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName,
+          getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel,
+          awsAccessKeyId, awsSecretKey)
+      }
+    }
+  }
+
+  // Throw IllegalArgumentException unless both values are null or neither are.
+  private def validateAwsCreds(awsAccessKeyId: String, awsSecretKey: String) {
     if (awsAccessKeyId == null && awsSecretKey != null) {
       throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null")
     }
     if (awsAccessKeyId != null && awsSecretKey == null) {
       throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null")
     }
-    if (awsAccessKeyId == null && awsSecretKey == null) {
-      KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName,
-        getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel)
-    } else {
-      KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName,
-        getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel,
-        awsAccessKeyId, awsSecretKey)
-    }
   }
-
 }
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala
new file mode 100644
index 0000000000..aa6fe12edf
--- /dev/null
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.streaming.kinesis
+
+import scala.collection.JavaConverters._
+
+import com.amazonaws.auth._
+
+import org.apache.spark.internal.Logging
+
+/**
+ * Serializable interface providing a method executors can call to obtain an
+ * AWSCredentialsProvider instance for authenticating to AWS services.
+ */
+private[kinesis] sealed trait SerializableCredentialsProvider extends Serializable {
+  /**
+   * Return an AWSCredentialsProvider instance that can be used by the Kinesis Client
+   * Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB).
+   */
+  def provider: AWSCredentialsProvider
+}
+
+/** Returns DefaultAWSCredentialsProviderChain for authentication. */
+private[kinesis] final case object DefaultCredentialsProvider
+  extends SerializableCredentialsProvider {
+
+  def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain
+}
+
+/**
+ * Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. Falls back to using
+ * DefaultAWSCredentialsProviderChain if unable to construct an AWSStaticCredentialsProvider
+ * instance with the provided arguments (e.g. if they are null).
+ */
+private[kinesis] final case class BasicCredentialsProvider(
+    awsAccessKeyId: String,
+    awsSecretKey: String) extends SerializableCredentialsProvider with Logging {
+
+  def provider: AWSCredentialsProvider = try {
+    new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey))
+  } catch {
+    case e: IllegalArgumentException =>
+      logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " +
+        "falling back to DefaultAWSCredentialsProviderChain.", e)
+      new DefaultAWSCredentialsProviderChain
+  }
+}
+
+/**
+ * Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM
+ * role in order to authenticate against resources in an external account.
+ */
+private[kinesis] final case class STSCredentialsProvider(
+    stsRoleArn: String,
+    stsSessionName: String,
+    stsExternalId: Option[String] = None,
+    longLivedCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider)
+  extends SerializableCredentialsProvider {
+
+  def provider: AWSCredentialsProvider = {
+    val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName)
+      .withLongLivedCredentialsProvider(longLivedCredsProvider.provider)
+    stsExternalId match {
+      case Some(stsExternalId) =>
+        builder.withExternalId(stsExternalId)
+          .build()
+      case None =>
+        builder.build()
+    }
+  }
+}
diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
index f078973c6c..26b1fda2ff 100644
--- a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
+++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
@@ -36,7 +36,7 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext {
   @Test
   public void testKinesisStream() {
     String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl();
-    String dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName();
+    String dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl);
 
     // Tests the API, does not actually test data receiving
     JavaDStream<byte[]> kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream",
@@ -45,6 +45,17 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext {
     ssc.stop();
   }
 
+  @Test
+  public void testAwsCreds() {
+    String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl();
+    String dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl);
+
+    // Tests the API, does not actually test data receiving
+    JavaDStream<byte[]> kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream",
+        dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, new Duration(2000),
+        StorageLevel.MEMORY_AND_DISK_2(), "fakeAccessKey", "fakeSecretKey");
+    ssc.stop();
+  }
 
   private static Function<Record, String> handler = new Function<Record, String>() {
     @Override
@@ -62,4 +73,26 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext {
 
     ssc.stop();
   }
+
+  @Test
+  public void testCustomHandlerAwsCreds() {
+    // Tests the API, does not actually test data receiving
+    JavaDStream<String> kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream",
+        "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST,
+        new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class,
+        "fakeAccessKey", "fakeSecretKey");
+
+    ssc.stop();
+  }
+
+  @Test
+  public void testCustomHandlerAwsStsCreds() {
+    // Tests the API, does not actually test data receiving
+    JavaDStream<String> kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream",
+        "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST,
+        new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class,
+        "fakeAccessKey", "fakeSecretKey", "fakeSTSRoleArn", "fakeSTSSessionName", "fakeSTSExternalId");
+
+    ssc.stop();
+  }
 }
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
index 800502a77d..deb411d73e 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -22,7 +22,7 @@ import java.util.Arrays
 
 import com.amazonaws.services.kinesis.clientlibrary.exceptions._
 import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
-import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
+import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason
 import com.amazonaws.services.kinesis.model.Record
 import org.mockito.Matchers._
 import org.mockito.Matchers.{eq => meq}
@@ -62,9 +62,26 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
     checkpointerMock = mock[IRecordProcessorCheckpointer]
   }
 
-  test("check serializability of SerializableAWSCredentials") {
-    Utils.deserialize[SerializableAWSCredentials](
-      Utils.serialize(new SerializableAWSCredentials("x", "y")))
+  test("check serializability of credential provider classes") {
+    Utils.deserialize[BasicCredentialsProvider](
+      Utils.serialize(BasicCredentialsProvider(
+        awsAccessKeyId = "x",
+        awsSecretKey = "y")))
+
+    Utils.deserialize[STSCredentialsProvider](
+      Utils.serialize(STSCredentialsProvider(
+        stsRoleArn = "fakeArn",
+        stsSessionName = "fakeSessionName",
+        stsExternalId = Some("fakeExternalId"))))
+
+    Utils.deserialize[STSCredentialsProvider](
+      Utils.serialize(STSCredentialsProvider(
+        stsRoleArn = "fakeArn",
+        stsSessionName = "fakeSessionName",
+        stsExternalId = Some("fakeExternalId"),
+        longLivedCredsProvider = BasicCredentialsProvider(
+          awsAccessKeyId = "x",
+          awsSecretKey = "y"))))
   }
 
   test("process records including store and set checkpointer") {
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 404b673c01..387a96f26b 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -49,7 +49,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
 
   // Dummy parameters for API testing
   private val dummyEndpointUrl = defaultEndpointUrl
-  private val dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName()
+  private val dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl)
   private val dummyAWSAccessKey = "dummyAccessKey"
   private val dummyAWSSecretKey = "dummySecretKey"
 
@@ -138,8 +138,9 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
     assert(kinesisRDD.regionName === dummyRegionName)
     assert(kinesisRDD.endpointUrl === dummyEndpointUrl)
     assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds)
-    assert(kinesisRDD.awsCredentialsOption ===
-      Some(SerializableAWSCredentials(dummyAWSAccessKey, dummyAWSSecretKey)))
+    assert(kinesisRDD.kinesisCredsProvider === BasicCredentialsProvider(
+      awsAccessKeyId = dummyAWSAccessKey,
+      awsSecretKey = dummyAWSSecretKey))
     assert(nonEmptyRDD.partitions.size === blockInfos.size)
     nonEmptyRDD.partitions.foreach { _ shouldBe a [KinesisBackedBlockRDDPartition] }
     val partitions = nonEmptyRDD.partitions.map {
@@ -201,7 +202,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
     def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5
     val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName,
       testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST,
-      Seconds(10), StorageLevel.MEMORY_ONLY, addFive,
+      Seconds(10), StorageLevel.MEMORY_ONLY, addFive(_),
       awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
 
     stream shouldBe a [ReceiverInputDStream[_]]
diff --git a/pom.xml b/pom.xml
index 60e4c7269e..c1174593c1 100644
--- a/pom.xml
+++ b/pom.xml
@@ -145,7 +145,9 @@
     <avro.version>1.7.7</avro.version>
     <avro.mapred.classifier>hadoop2</avro.mapred.classifier>
     <jets3t.version>0.9.3</jets3t.version>
-    <aws.kinesis.client.version>1.6.2</aws.kinesis.client.version>
+    <aws.kinesis.client.version>1.7.3</aws.kinesis.client.version>
+    <!-- Should be consistent with Kinesis client dependency -->
+    <aws.java.sdk.version>1.11.76</aws.java.sdk.version>
     <!-- the producer is used in tests -->
     <aws.kinesis.producer.version>0.10.2</aws.kinesis.producer.version>
     <!--  org.apache.httpcomponents/httpclient-->
diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py
index 3a8d8b819f..b839859c45 100644
--- a/python/pyspark/streaming/kinesis.py
+++ b/python/pyspark/streaming/kinesis.py
@@ -37,7 +37,8 @@ class KinesisUtils(object):
     def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName,
                      initialPositionInStream, checkpointInterval,
                      storageLevel=StorageLevel.MEMORY_AND_DISK_2,
-                     awsAccessKeyId=None, awsSecretKey=None, decoder=utf8_decoder):
+                     awsAccessKeyId=None, awsSecretKey=None, decoder=utf8_decoder,
+                     stsAssumeRoleArn=None, stsSessionName=None, stsExternalId=None):
         """
         Create an input stream that pulls messages from a Kinesis stream. This uses the
         Kinesis Client Library (KCL) to pull messages from Kinesis.
@@ -67,6 +68,12 @@ class KinesisUtils(object):
         :param awsSecretKey:  AWS SecretKey (default is None. If None, will use
                               DefaultAWSCredentialsProviderChain)
         :param decoder:  A function used to decode value (default is utf8_decoder)
+        :param stsAssumeRoleArn: ARN of IAM role to assume when using STS sessions to read from
+                                 the Kinesis stream (default is None).
+        :param stsSessionName: Name to uniquely identify STS sessions used to read from Kinesis
+                               stream, if STS is being used (default is None).
+        :param stsExternalId: External ID that can be used to validate against the assumed IAM
+                              role's trust policy, if STS is being used (default is None).
         :return: A DStream object
         """
         jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
@@ -81,7 +88,8 @@ class KinesisUtils(object):
             raise
         jstream = helper.createStream(ssc._jssc, kinesisAppName, streamName, endpointUrl,
                                       regionName, initialPositionInStream, jduration, jlevel,
-                                      awsAccessKeyId, awsSecretKey)
+                                      awsAccessKeyId, awsSecretKey, stsAssumeRoleArn,
+                                      stsSessionName, stsExternalId)
         stream = DStream(jstream, ssc, NoOpSerializer())
         return stream.map(lambda v: decoder(v))
 
-- 
GitLab