diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index b2bac7c938ab5cbdc5916316c4700d2ad715b51f..daa79e79163b9f4d7c6d5efb95219bbef445ac59 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -58,6 +58,11 @@ <artifactId>amazon-kinesis-client</artifactId> <version>${aws.kinesis.client.version}</version> </dependency> + <dependency> + <groupId>com.amazonaws</groupId> + <artifactId>aws-java-sdk-sts</artifactId> + <version>${aws.java.sdk.version}</version> + </dependency> <dependency> <groupId>com.amazonaws</groupId> <artifactId>amazon-kinesis-producer</artifactId> diff --git a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index d40bd3ff560d6fb603d19f14213b2cbb16dfa718..d1274a687fc701eb63e4e4e30b00e76a5a239cc8 100644 --- a/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -127,7 +127,7 @@ public final class JavaKinesisWordCountASL { // needs to be public for access fr // Get the region name from the endpoint URL to save Kinesis Client Library metadata in // DynamoDB of the same region as the Kinesis stream - String regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName(); + String regionName = KinesisExampleUtils.getRegionNameByEndpoint(endpointUrl); // Setup the Spark config and StreamingContext SparkConf sparkConfig = new SparkConf().setAppName("JavaKinesisWordCountASL"); diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..2eebd6130d4da7624f03d312e460c4eece2276aa --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.streaming + +import scala.collection.JavaConverters._ + +import com.amazonaws.regions.RegionUtils +import com.amazonaws.services.kinesis.AmazonKinesis + +private[streaming] object KinesisExampleUtils { + def getRegionNameByEndpoint(endpoint: String): String = { + val uri = new java.net.URI(endpoint) + RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX) + .asScala + .find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost)) + .map(_.getName) + .getOrElse( + throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint")) + } +} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index a70c13d7d68a83452ae475398851d8816b781610..f14117b708a0d1b3c87eccd63a14fda6a6993304 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -127,7 +127,7 @@ object KinesisWordCountASL extends Logging { // Get the region name from the endpoint URL to save Kinesis Client Library metadata in // DynamoDB of the same region as the Kinesis stream - val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + val regionName = KinesisExampleUtils.getRegionNameByEndpoint(endpointUrl) // Setup the SparkConfig and StreamingContext val sparkConfig = new SparkConf().setAppName("KinesisWordCountASL") diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 45dc3c388cb8dcce7424d460032994cbdefdd122..23c4d99e50f5188a77429148779a6bd27e7a6348 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -79,7 +79,7 @@ class KinesisBackedBlockRDD[T: ClassTag]( @transient private val isBlockIdValid: Array[Boolean] = Array.empty, val retryTimeoutMs: Int = 10000, val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _, - val awsCredentialsOption: Option[SerializableAWSCredentials] = None + val kinesisCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider ) extends BlockRDD[T](sc, _blockIds) { require(_blockIds.length == arrayOfseqNumberRanges.length, @@ -105,9 +105,7 @@ class KinesisBackedBlockRDD[T: ClassTag]( } def getBlockFromKinesis(): Iterator[T] = { - val credentials = awsCredentialsOption.getOrElse { - new DefaultAWSCredentialsProviderChain().getCredentials() - } + val credentials = kinesisCredsProvider.provider.getCredentials partition.seqNumberRanges.ranges.iterator.flatMap { range => new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName, range, retryTimeoutMs).map(messageHandler) @@ -143,7 +141,7 @@ class KinesisSequenceRangeIterator( private var lastSeqNumber: String = null private var internalIterator: Iterator[Record] = null - client.setEndpoint(endpointUrl, "kinesis", regionId) + client.setEndpoint(endpointUrl) override protected def getNext(): Record = { var nextRecord: Record = null diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala 
b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala index c445c15a5f644367c05e850c9738ddcc9ea572f2..5fb83b26f83824390d30c517003eae968494209f 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala @@ -21,7 +21,7 @@ import java.util.concurrent._ import scala.util.control.NonFatal import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer -import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason import org.apache.spark.internal.Logging import org.apache.spark.streaming.Duration diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 5223c81a8e0e07b3e8de885913b39bf39644e2ac..fbc6b99443ed70a3386f41346b0fddb58fece47b 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -39,7 +39,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( checkpointInterval: Duration, storageLevel: StorageLevel, messageHandler: Record => T, - awsCredentialsOption: Option[SerializableAWSCredentials] + kinesisCredsProvider: SerializableCredentialsProvider ) extends ReceiverInputDStream[T](_ssc) { private[streaming] @@ -61,7 +61,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( isBlockIdValid = isBlockIdValid, retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt, messageHandler = messageHandler, - awsCredentialsOption = awsCredentialsOption) + kinesisCredsProvider = kinesisCredsProvider) } else { logWarning("Kinesis sequence number information was not present with some block metadata," + " it may not be possible to recover from failures") @@ -71,6 +71,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( override def getReceiver(): Receiver[T] = { new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream, - checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption) + checkpointAppName, checkpointInterval, storageLevel, messageHandler, + kinesisCredsProvider) } } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 393e56a39320c7670eb063278392c2d1bbd63ffa..13fc54e531ddaf512bd254944cfca2fc17c0e50e 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.control.NonFatal -import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} import com.amazonaws.services.kinesis.model.Record @@ 
-34,13 +33,6 @@ import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} import org.apache.spark.util.Utils -private[kinesis] -case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) - extends AWSCredentials { - override def getAWSAccessKeyId: String = accessKeyId - override def getAWSSecretKey: String = secretKey -} - /** * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: @@ -78,8 +70,9 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) * See the Kinesis Spark Streaming documentation for more * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects - * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies - * the credentials + * @param kinesisCredsProvider SerializableCredentialsProvider instance that will be used to + * generate the AWSCredentialsProvider instance used for KCL + * authorization. */ private[kinesis] class KinesisReceiver[T]( val streamName: String, @@ -90,7 +83,7 @@ private[kinesis] class KinesisReceiver[T]( checkpointInterval: Duration, storageLevel: StorageLevel, messageHandler: Record => T, - awsCredentialsOption: Option[SerializableAWSCredentials]) + kinesisCredsProvider: SerializableCredentialsProvider) extends Receiver[T](storageLevel) with Logging { receiver => /* @@ -147,14 +140,15 @@ private[kinesis] class KinesisReceiver[T]( workerId = Utils.localHostName() + ":" + UUID.randomUUID() kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) - // KCL config instance - val awsCredProvider = resolveAWSCredentialsProvider() - val kinesisClientLibConfiguration = - new KinesisClientLibConfiguration(checkpointAppName, streamName, awsCredProvider, workerId) - .withKinesisEndpoint(endpointUrl) - .withInitialPositionInStream(initialPositionInStream) - .withTaskBackoffTimeMillis(500) - .withRegionName(regionName) + val kinesisClientLibConfiguration = new KinesisClientLibConfiguration( + checkpointAppName, + streamName, + kinesisCredsProvider.provider, + workerId) + .withKinesisEndpoint(endpointUrl) + .withInitialPositionInStream(initialPositionInStream) + .withTaskBackoffTimeMillis(500) + .withRegionName(regionName) /* * RecordProcessorFactory creates impls of IRecordProcessor. @@ -305,25 +299,6 @@ private[kinesis] class KinesisReceiver[T]( } } - /** - * If AWS credential is provided, return a AWSCredentialProvider returning that credential. - * Otherwise, return the DefaultAWSCredentialsProviderChain. - */ - private def resolveAWSCredentialsProvider(): AWSCredentialsProvider = { - awsCredentialsOption match { - case Some(awsCredentials) => - logInfo("Using provided AWS credentials") - new AWSCredentialsProvider { - override def getCredentials: AWSCredentials = awsCredentials - override def refresh(): Unit = { } - } - case None => - logInfo("Using DefaultAWSCredentialsProviderChain") - new DefaultAWSCredentialsProviderChain() - } - } - - /** * Class to handle blocks generated by this receiver's block generator. Specifically, in * the context of the Kinesis Receiver, this handler does the following. 
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 73ccc4ad23f6d39002bc0388b9bce892a7a7eeb3..8c6a399dd763ea6e3245d1113630a179782d519c 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -23,7 +23,7 @@ import scala.util.control.NonFatal import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer} -import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.apache.spark.internal.Logging diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index f183ef00b33cdbf010a47cb6d42ca36c553a5294..73ac7a3cd2355e78447d0c0f374b3285a499b1ab 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -30,7 +30,7 @@ import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} import com.amazonaws.regions.RegionUtils import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient import com.amazonaws.services.dynamodbv2.document.DynamoDB -import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.{AmazonKinesis, AmazonKinesisClient} import com.amazonaws.services.kinesis.model._ import org.apache.spark.internal.Logging @@ -43,7 +43,7 @@ import org.apache.spark.internal.Logging private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Logging { val endpointUrl = KinesisTestUtils.endpointUrl - val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() + val regionName = KinesisTestUtils.getRegionNameByEndpoint(endpointUrl) private val createStreamTimeoutSeconds = 300 private val describeStreamPollTimeSeconds = 1 @@ -205,6 +205,16 @@ private[kinesis] object KinesisTestUtils { val endVarNameForEndpoint = "KINESIS_TEST_ENDPOINT_URL" val defaultEndpointUrl = "https://kinesis.us-west-2.amazonaws.com" + def getRegionNameByEndpoint(endpoint: String): String = { + val uri = new java.net.URI(endpoint) + RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX) + .asScala + .find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost)) + .map(_.getName) + .getOrElse( + throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint")) + } + lazy val shouldRunTests = { val isEnvSet = sys.env.get(envVarNameForEnablingTests) == Some("1") if (isEnvSet) { diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index b2daffa34ccbfaf540aa622ebcc178f5ca9954e5..2d777982e760c06de11d2f0203379382d5cd8629 100644 --- 
a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -73,7 +73,7 @@ object KinesisUtils { ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, None) + cleanedHandler, DefaultCredentialsProvider) } } @@ -123,9 +123,80 @@ object KinesisUtils { // scalastyle:on val cleanedHandler = ssc.sc.clean(messageHandler) ssc.withNamedScope("kinesis stream") { + val kinesisCredsProvider = BasicCredentialsProvider( + awsAccessKeyId = awsAccessKeyId, + awsSecretKey = awsSecretKey) new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - cleanedHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + cleanedHandler, kinesisCredsProvider) + } + } + + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * @param ssc StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from + * Kinesis stream. + * @param stsSessionName Name to uniquely identify STS sessions if multiple principals assume + * the same role. + * @param stsExternalId External ID that can be used to validate against the assumed IAM role's + * trust policy. + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure.
+ */ + // scalastyle:off + def createStream[T: ClassTag]( + ssc: StreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: Record => T, + awsAccessKeyId: String, + awsSecretKey: String, + stsAssumeRoleArn: String, + stsSessionName: String, + stsExternalId: String): ReceiverInputDStream[T] = { + // scalastyle:on + val cleanedHandler = ssc.sc.clean(messageHandler) + ssc.withNamedScope("kinesis stream") { + val kinesisCredsProvider = STSCredentialsProvider( + stsRoleArn = stsAssumeRoleArn, + stsSessionName = stsSessionName, + stsExternalId = Option(stsExternalId), + longLivedCredsProvider = BasicCredentialsProvider( + awsAccessKeyId = awsAccessKeyId, + awsSecretKey = awsSecretKey)) + new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName), + initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, + cleanedHandler, kinesisCredsProvider) } } @@ -169,7 +240,7 @@ object KinesisUtils { ssc.withNamedScope("kinesis stream") { new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, None) + defaultMessageHandler, DefaultCredentialsProvider) } } @@ -213,9 +284,12 @@ object KinesisUtils { awsAccessKeyId: String, awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = { ssc.withNamedScope("kinesis stream") { + val kinesisCredsProvider = BasicCredentialsProvider( + awsAccessKeyId = awsAccessKeyId, + awsSecretKey = awsSecretKey) new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName), initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, - defaultMessageHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey))) + defaultMessageHandler, kinesisCredsProvider) } } @@ -319,6 +393,68 @@ object KinesisUtils { awsAccessKeyId, awsSecretKey) } + /** + * Create an input stream that pulls messages from a Kinesis stream. + * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. + * + * @param jssc Java StreamingContext object + * @param kinesisAppName Kinesis application name used by the Kinesis Client Library + * (KCL) to update DynamoDB + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param regionName Name of region used by the Kinesis Client Library (KCL) to update + * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param storageLevel Storage level to use for storing the received objects. + * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param messageHandler A custom message handler that can generate a generic output from a + * Kinesis `Record`, which contains both message data, and metadata. 
+ * @param recordClass Class of the records in DStream + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * @param stsAssumeRoleArn ARN of IAM role to assume when using STS sessions to read from + * Kinesis stream. + * @param stsSessionName Name to uniquely identify STS sessions if multiple principals assume + * the same role. + * @param stsExternalId External ID that can be used to validate against the assumed IAM role's + * trust policy. + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. + */ + // scalastyle:off + def createStream[T]( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: InitialPositionInStream, + checkpointInterval: Duration, + storageLevel: StorageLevel, + messageHandler: JFunction[Record, T], + recordClass: Class[T], + awsAccessKeyId: String, + awsSecretKey: String, + stsAssumeRoleArn: String, + stsSessionName: String, + stsExternalId: String): JavaReceiverInputDStream[T] = { + // scalastyle:on + implicit val recordCmt: ClassTag[T] = ClassTag(recordClass) + val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_)) + createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler, + awsAccessKeyId, awsSecretKey, stsAssumeRoleArn, stsSessionName, stsExternalId) + } + /** * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
@@ -404,10 +540,6 @@ object KinesisUtils { defaultMessageHandler(_), awsAccessKeyId, awsSecretKey) } - private def getRegionByEndpoint(endpointUrl: String): String = { - RegionUtils.getRegionByEndpoint(endpointUrl).getName() - } - private def validateRegion(regionName: String): String = { Option(RegionUtils.getRegion(regionName)).map { _.getName }.getOrElse { throw new IllegalArgumentException(s"Region name '$regionName' is not valid") @@ -439,6 +571,7 @@ private class KinesisUtilsPythonHelper { } } + // scalastyle:off def createStream( jssc: JavaStreamingContext, kinesisAppName: String, @@ -449,22 +582,43 @@ private class KinesisUtilsPythonHelper { checkpointInterval: Duration, storageLevel: StorageLevel, awsAccessKeyId: String, - awsSecretKey: String - ): JavaReceiverInputDStream[Array[Byte]] = { + awsSecretKey: String, + stsAssumeRoleArn: String, + stsSessionName: String, + stsExternalId: String): JavaReceiverInputDStream[Array[Byte]] = { + // scalastyle:on + if (!(stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null) + && !(stsAssumeRoleArn == null && stsSessionName == null && stsExternalId == null)) { + throw new IllegalArgumentException("stsAssumeRoleArn, stsSessionName, and stsExternalId " + + "must all be defined or all be null") + } + + if (stsAssumeRoleArn != null && stsSessionName != null && stsExternalId != null) { + validateAwsCreds(awsAccessKeyId, awsSecretKey) + KinesisUtils.createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, + KinesisUtils.defaultMessageHandler(_), awsAccessKeyId, awsSecretKey, + stsAssumeRoleArn, stsSessionName, stsExternalId) + } else { + validateAwsCreds(awsAccessKeyId, awsSecretKey) + if (awsAccessKeyId == null && awsSecretKey == null) { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel) + } else { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, + awsAccessKeyId, awsSecretKey) + } + } + } + + // Throw IllegalArgumentException unless both values are null or neither are.
+ private def validateAwsCreds(awsAccessKeyId: String, awsSecretKey: String) { if (awsAccessKeyId == null && awsSecretKey != null) { throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null") } if (awsAccessKeyId != null && awsSecretKey == null) { throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null") } - if (awsAccessKeyId == null && awsSecretKey == null) { - KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, - getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel) - } else { - KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, - getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, - awsAccessKeyId, awsSecretKey) - } } - } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala new file mode 100644 index 0000000000000000000000000000000000000000..aa6fe12edf74e71b4abee29b3ed404f74db81daf --- /dev/null +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SerializableCredentialsProvider.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConverters._ + +import com.amazonaws.auth._ + +import org.apache.spark.internal.Logging + +/** + * Serializable interface providing a method executors can call to obtain an + * AWSCredentialsProvider instance for authenticating to AWS services. + */ +private[kinesis] sealed trait SerializableCredentialsProvider extends Serializable { + /** + * Return an AWSCredentialsProvider instance that can be used by the Kinesis Client + * Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB). + */ + def provider: AWSCredentialsProvider +} + +/** Returns DefaultAWSCredentialsProviderChain for authentication. */ +private[kinesis] final case object DefaultCredentialsProvider + extends SerializableCredentialsProvider { + + def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain +} + +/** + * Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. Falls back to using + * DefaultAWSCredentialsProviderChain if unable to construct an AWSStaticCredentialsProvider + * instance with the provided arguments (e.g. if they are null).
+ */ +private[kinesis] final case class BasicCredentialsProvider( + awsAccessKeyId: String, + awsSecretKey: String) extends SerializableCredentialsProvider with Logging { + + def provider: AWSCredentialsProvider = try { + new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey)) + } catch { + case e: IllegalArgumentException => + logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " + + "falling back to DefaultAWSCredentialsProviderChain.", e) + new DefaultAWSCredentialsProviderChain + } +} + +/** + * Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM + * role in order to authenticate against resources in an external account. + */ +private[kinesis] final case class STSCredentialsProvider( + stsRoleArn: String, + stsSessionName: String, + stsExternalId: Option[String] = None, + longLivedCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider) + extends SerializableCredentialsProvider { + + def provider: AWSCredentialsProvider = { + val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName) + .withLongLivedCredentialsProvider(longLivedCredsProvider.provider) + stsExternalId match { + case Some(stsExternalId) => + builder.withExternalId(stsExternalId) + .build() + case None => + builder.build() + } + } +} diff --git a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java index f078973c6c285a9e4d7de24806c5ffa2104bef59..26b1fda2ff5114cd91cc4978175252d308557ed0 100644 --- a/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java +++ b/external/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java @@ -36,7 +36,7 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext { @Test public void testKinesisStream() { String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl(); - String dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName(); + String dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl); // Tests the API, does not actually test data receiving JavaDStream<byte[]> kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream", @@ -45,6 +45,17 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext { ssc.stop(); } + @Test + public void testAwsCreds() { + String dummyEndpointUrl = KinesisTestUtils.defaultEndpointUrl(); + String dummyRegionName = KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl); + + // Tests the API, does not actually test data receiving + JavaDStream<byte[]> kinesisStream = KinesisUtils.createStream(ssc, "myAppName", "mySparkStream", + dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, new Duration(2000), + StorageLevel.MEMORY_AND_DISK_2(), "fakeAccessKey", "fakeSecretKey"); + ssc.stop(); + } private static Function<Record, String> handler = new Function<Record, String>() { @Override @@ -62,4 +73,26 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext { ssc.stop(); } + + @Test + public void testCustomHandlerAwsCreds() { + // Tests the API, does not actually test data receiving + JavaDStream<String> kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST, + new 
Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class, + "fakeAccessKey", "fakeSecretKey"); + + ssc.stop(); + } + + @Test + public void testCustomHandlerAwsStsCreds() { + // Tests the API, does not actually test data receiving + JavaDStream<String> kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST, + new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class, + "fakeAccessKey", "fakeSecretKey", "fakeSTSRoleArn", "fakeSTSSessionName", "fakeSTSExternalId"); + + ssc.stop(); + } } diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 800502a77d120f9180ab06f359d794fdb974760c..deb411d73e5889b02007dc71f0cd860fc9d507e7 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -22,7 +22,7 @@ import java.util.Arrays import com.amazonaws.services.kinesis.clientlibrary.exceptions._ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer -import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.mockito.Matchers._ import org.mockito.Matchers.{eq => meq} @@ -62,9 +62,26 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft checkpointerMock = mock[IRecordProcessorCheckpointer] } - test("check serializability of SerializableAWSCredentials") { - Utils.deserialize[SerializableAWSCredentials]( - Utils.serialize(new SerializableAWSCredentials("x", "y"))) + test("check serializability of credential provider classes") { + Utils.deserialize[BasicCredentialsProvider]( + Utils.serialize(BasicCredentialsProvider( + awsAccessKeyId = "x", + awsSecretKey = "y"))) + + Utils.deserialize[STSCredentialsProvider]( + Utils.serialize(STSCredentialsProvider( + stsRoleArn = "fakeArn", + stsSessionName = "fakeSessionName", + stsExternalId = Some("fakeExternalId")))) + + Utils.deserialize[STSCredentialsProvider]( + Utils.serialize(STSCredentialsProvider( + stsRoleArn = "fakeArn", + stsSessionName = "fakeSessionName", + stsExternalId = Some("fakeExternalId"), + longLivedCredsProvider = BasicCredentialsProvider( + awsAccessKeyId = "x", + awsSecretKey = "y")))) } test("process records including store and set checkpointer") { diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 404b673c011718a96c96dc5fb09a1bad2f7d2324..387a96f26b3058442b83c6238734e094e756adf5 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -49,7 +49,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun // Dummy parameters for API testing private val dummyEndpointUrl = defaultEndpointUrl - private val dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName() + private val dummyRegionName = 
KinesisTestUtils.getRegionNameByEndpoint(dummyEndpointUrl) private val dummyAWSAccessKey = "dummyAccessKey" private val dummyAWSSecretKey = "dummySecretKey" @@ -138,8 +138,9 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun assert(kinesisRDD.regionName === dummyRegionName) assert(kinesisRDD.endpointUrl === dummyEndpointUrl) assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) - assert(kinesisRDD.awsCredentialsOption === - Some(SerializableAWSCredentials(dummyAWSAccessKey, dummyAWSSecretKey))) + assert(kinesisRDD.kinesisCredsProvider === BasicCredentialsProvider( + awsAccessKeyId = dummyAWSAccessKey, + awsSecretKey = dummyAWSSecretKey)) assert(nonEmptyRDD.partitions.size === blockInfos.size) nonEmptyRDD.partitions.foreach { _ shouldBe a [KinesisBackedBlockRDDPartition] } val partitions = nonEmptyRDD.partitions.map { @@ -201,7 +202,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5 val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName, testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST, - Seconds(10), StorageLevel.MEMORY_ONLY, addFive, + Seconds(10), StorageLevel.MEMORY_ONLY, addFive(_), awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) stream shouldBe a [ReceiverInputDStream[_]] diff --git a/pom.xml b/pom.xml index 60e4c7269eafd51caaf82994aa4b4544f6a073a7..c1174593c19220d2511df0c92b219f1802d5fb80 100644 --- a/pom.xml +++ b/pom.xml @@ -145,7 +145,9 @@ <avro.version>1.7.7</avro.version> <avro.mapred.classifier>hadoop2</avro.mapred.classifier> <jets3t.version>0.9.3</jets3t.version> - <aws.kinesis.client.version>1.6.2</aws.kinesis.client.version> + <aws.kinesis.client.version>1.7.3</aws.kinesis.client.version> + <!-- Should be consistent with Kinesis client dependency --> + <aws.java.sdk.version>1.11.76</aws.java.sdk.version> <!-- the producer is used in tests --> <aws.kinesis.producer.version>0.10.2</aws.kinesis.producer.version> <!-- org.apache.httpcomponents/httpclient--> diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py index 3a8d8b819fd370d4cab828ac83f70c7537edd884..b839859c45252ae7ee472a1ab8fde27aa06233c0 100644 --- a/python/pyspark/streaming/kinesis.py +++ b/python/pyspark/streaming/kinesis.py @@ -37,7 +37,8 @@ class KinesisUtils(object): def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, initialPositionInStream, checkpointInterval, storageLevel=StorageLevel.MEMORY_AND_DISK_2, - awsAccessKeyId=None, awsSecretKey=None, decoder=utf8_decoder): + awsAccessKeyId=None, awsSecretKey=None, decoder=utf8_decoder, + stsAssumeRoleArn=None, stsSessionName=None, stsExternalId=None): """ Create an input stream that pulls messages from a Kinesis stream. This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. @@ -67,6 +68,12 @@ class KinesisUtils(object): :param awsSecretKey: AWS SecretKey (default is None. If None, will use DefaultAWSCredentialsProviderChain) :param decoder: A function used to decode value (default is utf8_decoder) + :param stsAssumeRoleArn: ARN of IAM role to assume when using STS sessions to read from + the Kinesis stream (default is None). + :param stsSessionName: Name to uniquely identify STS sessions used to read from Kinesis + stream, if STS is being used (default is None). 
+ :param stsExternalId: External ID that can be used to validate against the assumed IAM + role's trust policy, if STS is being used (default is None). :return: A DStream object """ jlevel = ssc._sc._getJavaStorageLevel(storageLevel) @@ -81,7 +88,8 @@ class KinesisUtils(object): raise jstream = helper.createStream(ssc._jssc, kinesisAppName, streamName, endpointUrl, regionName, initialPositionInStream, jduration, jlevel, - awsAccessKeyId, awsSecretKey) + awsAccessKeyId, awsSecretKey, stsAssumeRoleArn, + stsSessionName, stsExternalId) stream = DStream(jstream, ssc, NoOpSerializer()) return stream.map(lambda v: decoder(v))
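
For reference, below is a minimal sketch of how a driver program might call the new STS-enabled Scala createStream overload introduced by this patch. The application name, stream name, endpoint, region, role ARN, session name, and external ID are placeholders, not values from the patch; passing null for awsAccessKeyId/awsSecretKey relies on the documented fallback to DefaultAWSCredentialsProviderChain for the long-lived credentials used to assume the role.

import java.nio.charset.StandardCharsets

import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.KinesisUtils

object KinesisStsExample {
  def main(args: Array[String]): Unit = {
    val ssc = new StreamingContext(new SparkConf().setAppName("KinesisStsExample"), Seconds(10))

    // Decode each Kinesis record payload as a UTF-8 string.
    def decode(record: Record): String = {
      val bytes = new Array[Byte](record.getData.remaining())
      record.getData.get(bytes)
      new String(bytes, StandardCharsets.UTF_8)
    }

    // awsAccessKeyId/awsSecretKey are null, so the long-lived credentials fall back to
    // DefaultAWSCredentialsProviderChain and are used only to assume the given role (placeholder ARN).
    val lines = KinesisUtils.createStream(
      ssc, "exampleApp", "exampleStream",
      "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
      InitialPositionInStream.LATEST, Seconds(10), StorageLevel.MEMORY_AND_DISK_2,
      decode _,
      null, null,
      "arn:aws:iam::123456789012:role/exampleKinesisReadRole",
      "exampleSessionName", "exampleExternalId")

    lines.print()
    ssc.start()
    ssc.awaitTermination()
  }
}

The same parameters map directly onto the new Python createStream keyword arguments (stsAssumeRoleArn, stsSessionName, stsExternalId), which the KinesisUtilsPythonHelper changes above forward to this overload.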