From 0f1d00a905614bb5eebf260566dbcb831158d445 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu <shixiong@databricks.com>
Date: Thu, 12 Nov 2015 17:48:43 -0800
Subject: [PATCH] [SPARK-11663][STREAMING] Add Java API for trackStateByKey

TODO
- [x] Add Java API
- [x] Add API tests
- [x] Add a function test

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #9636 from zsxwing/java-track.
---
 .../spark/api/java/function/Function4.java    |  27 +++
 .../JavaStatefulNetworkWordCount.java         |  45 ++--
 .../streaming/StatefulNetworkWordCount.scala  |   2 +-
 .../apache/spark/streaming/Java8APISuite.java |  43 ++++
 .../org/apache/spark/streaming/State.scala    |  25 ++-
 .../apache/spark/streaming/StateSpec.scala    |  84 +++++--
 .../streaming/api/java/JavaPairDStream.scala  |  46 +++-
 .../api/java/JavaTrackStateDStream.scala      |  44 ++++
 .../streaming/dstream/TrackStateDStream.scala |   1 +
 .../spark/streaming/rdd/TrackStateRDD.scala   |   4 +-
 .../spark/streaming/util/StateMap.scala       |   6 +-
 .../streaming/JavaTrackStateByKeySuite.java   | 210 ++++++++++++++++++
 12 files changed, 485 insertions(+), 52 deletions(-)
 create mode 100644 core/src/main/java/org/apache/spark/api/java/function/Function4.java
 create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
 create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java

diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function4.java b/core/src/main/java/org/apache/spark/api/java/function/Function4.java
new file mode 100644
index 0000000000..fd727d6486
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function4.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+
+/**
+ * A four-argument function that takes arguments of type T1, T2, T3 and T4 and returns an R.
+ */
+public interface Function4<T1, T2, T3, T4, R> extends Serializable {
+  public R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception;
+}
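This interface is consumed below by the new StateSpec.function overload as the time-aware tracking function. For orientation, a minimal illustrative sketch of an implementation — not part of the patch; the word-count logic and names are invented:

    import com.google.common.base.Optional;
    import org.apache.spark.api.java.function.Function4;
    import org.apache.spark.streaming.State;
    import org.apache.spark.streaming.Time;

    // Sums the counts seen for a key and emits the running total.
    Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Integer>> sumFunc =
        new Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Integer>>() {
          @Override
          public Optional<Integer> call(
              Time time, String key, Optional<Integer> value, State<Integer> state) {
            int sum = value.or(0) + (state.exists() ? state.get() : 0);
            state.update(sum);  // keep the running total as this key's state
            return Optional.of(sum);
          }
        };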
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
index 99b63a2590..c400e4237a 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -26,18 +26,15 @@ import scala.Tuple2;
 import com.google.common.base.Optional;
 import com.google.common.collect.Lists;
 
-import org.apache.spark.HashPartitioner;
 import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.function.*;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.StorageLevels;
-import org.apache.spark.api.java.function.FlatMapFunction;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.streaming.Durations;
-import org.apache.spark.streaming.api.java.JavaDStream;
-import org.apache.spark.streaming.api.java.JavaPairDStream;
-import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
-import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.apache.spark.streaming.State;
+import org.apache.spark.streaming.StateSpec;
+import org.apache.spark.streaming.Time;
+import org.apache.spark.streaming.api.java.*;
 
 /**
  * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every
@@ -63,25 +60,12 @@ public class JavaStatefulNetworkWordCount {
 
     StreamingExamples.setStreamingLogLevels();
 
-    // Update the cumulative count function
-    final Function2<List<Integer>, Optional<Integer>, Optional<Integer>> updateFunction =
-        new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
-          @Override
-          public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
-            Integer newSum = state.or(0);
-            for (Integer value : values) {
-              newSum += value;
-            }
-            return Optional.of(newSum);
-          }
-        };
-
     // Create the context with a 1 second batch size
     SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount");
     JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
     ssc.checkpoint(".");
 
-    // Initial RDD input to updateStateByKey
+    // Initial RDD input to trackStateByKey
     @SuppressWarnings("unchecked")
     List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
         new Tuple2<String, Integer>("world", 1));
@@ -105,9 +89,22 @@ public class JavaStatefulNetworkWordCount {
           }
         });
 
+    // Update the cumulative count function
+    final Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>> trackStateFunc =
+        new Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>>() {
+
+          @Override
+          public Optional<Tuple2<String, Integer>> call(
+              Time time, String word, Optional<Integer> one, State<Integer> state) {
+            int sum = one.or(0) + (state.exists() ? state.get() : 0);
+            Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, sum);
+            state.update(sum);
+            return Optional.of(output);
+          }
+        };
+
     // This will give a Dstream made of state (which is the cumulative count of the words)
-    JavaPairDStream<String, Integer> stateDstream = wordsDstream.updateStateByKey(updateFunction,
-        new HashPartitioner(ssc.sparkContext().defaultParallelism()), initialRDD);
+    JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
+        wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD));
 
     stateDstream.print();
     ssc.start();
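The example prints each emitted (word, cumulative count) pair. The new API can additionally expose the full per-key state; a hedged follow-on sketch using the stateDstream variable from the example above (not part of the patch):

    // Illustrative: view the complete (word, count) state of all keys after each batch.
    JavaPairDStream<String, Integer> snapshots = stateDstream.stateSnapshots();
    snapshots.print();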
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
index be2ae0b473..a4f847f118 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
@@ -49,7 +49,7 @@ object StatefulNetworkWordCount {
     val ssc = new StreamingContext(sparkConf, Seconds(1))
     ssc.checkpoint(".")
 
-    // Initial RDD input to updateStateByKey
+    // Initial RDD input to trackStateByKey
     val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
 
     // Create a ReceiverInputDStream on target ip:port and count the
diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
index 73091cfe2c..163ae92c12 100644
--- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
+++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
@@ -31,9 +31,12 @@ import org.junit.Test;
 import org.apache.spark.HashPartitioner;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.Function4;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.streaming.api.java.JavaDStream;
 import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
 
 /**
  * Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8
@@ -831,4 +834,44 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
     Assert.assertEquals(expected, result);
   }
 
+  /**
+   * This test is only for testing the APIs. It's not necessary to run it.
+   */
+  public void testTrackStateByAPI() {
+    JavaPairRDD<String, Boolean> initialRDD = null;
+    JavaPairDStream<String, Integer> wordsDstream = null;
+
+    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
+        wordsDstream.trackStateByKey(
+            StateSpec.<String, Integer, Boolean, Double>function((time, key, value, state) -> {
+              // Use all State's methods here
+              state.exists();
+              state.get();
+              state.isTimingOut();
+              state.remove();
+              state.update(true);
+              return Optional.of(2.0);
+            }).initialState(initialRDD)
+                .numPartitions(10)
+                .partitioner(new HashPartitioner(10))
+                .timeout(Durations.seconds(10)));
+
+    JavaPairDStream<String, Boolean> stateSnapshots = stateDstream.stateSnapshots();
+
+    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
+        wordsDstream.trackStateByKey(
+            StateSpec.<String, Integer, Boolean, Double>function((value, state) -> {
+              state.exists();
+              state.get();
+              state.isTimingOut();
+              state.remove();
+              state.update(true);
+              return 2.0;
+            }).initialState(initialRDD)
+                .numPartitions(10)
+                .partitioner(new HashPartitioner(10))
+                .timeout(Durations.seconds(10)));
+
+    JavaPairDStream<String, Boolean> stateSnapshots2 = stateDstream2.stateSnapshots();
+  }
 }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
index 7dd1b72f80..604e64fc61 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -50,9 +50,30 @@ import org.apache.spark.annotation.Experimental
  *
  *    }}}
  *
- * Java example:
+ * Java example of using `State`:
  * {{{
- *    TODO(@zsxwing)
+ *    // A tracking function that maintains an integer state and returns a String
+ *    Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
+ *        new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
+ *
+ *          @Override
+ *          public Optional<String> call(Optional<Integer> one, State<Integer> state) {
+ *            if (state.exists()) {
+ *              int existingState = state.get(); // Get the existing state
+ *              boolean shouldRemove = ...; // Decide whether to remove the state
+ *              if (shouldRemove) {
+ *                state.remove(); // Remove the state
+ *              } else {
+ *                int newState = ...;
+ *                state.update(newState); // Set the new state
+ *              }
+ *            } else {
+ *              int initialState = ...; // Set the initial state
+ *              state.update(initialState);
+ *            }
+ *            // return something
+ *          }
+ *        };
+ * }}}
  */
 @Experimental
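The javadoc above deliberately elides the decision logic. For illustration only, a compilable sketch of one concrete policy, assuming the timeout semantics of trackStateByKey (a key that is timing out can be read but must not be updated):

    // Illustrative: running count per key; emit a final record when the key times out.
    Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
        new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
          @Override
          public Optional<String> call(Optional<Integer> one, State<Integer> state) {
            if (state.isTimingOut()) {
              // State is being evicted by the configured timeout; do not call update() here.
              return Optional.of("expired: " + state.get());
            }
            int sum = one.or(0) + (state.exists() ? state.get() : 0);
            state.update(sum);
            return Optional.of("count: " + sum);
          }
        };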
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
index c9fe35e74c..bea5b9df20 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -17,15 +17,14 @@
 
 package org.apache.spark.streaming
 
-import scala.reflect.ClassTag
-
+import com.google.common.base.Optional
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.JavaPairRDD
+import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
+import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.ClosureCleaner
 import org.apache.spark.{HashPartitioner, Partitioner}
 
-
 /**
  * :: Experimental ::
  * Abstract class representing all the specifications of the DStream transformation
@@ -49,12 +48,12 @@ import org.apache.spark.{HashPartitioner, Partitioner}
  *
  * Example in Java:
  * {{{
- *    StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
- *       StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
+ *    StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
+ *       StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
  *                     .numPartitions(10);
  *
- *    JavaDStream[EmittedDataType] emittedRecordDStream =
- *       javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
+ *    JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedDataType> emittedRecordDStream =
+ *       javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
  * }}}
  */
 @Experimental
@@ -92,6 +91,7 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte
 /**
  * :: Experimental ::
  * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]]
+ * that is used for specifying the parameters of the DStream transformation `trackStateByKey`
  * that is used for specifying the parameters of the DStream transformation
  * `trackStateByKey` operation of a
  * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
@@ -103,28 +103,27 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte
  *    ...
  *    }
  *
- *    val spec = StateSpec.function(trackingFunction).numPartitions(10)
- *
- *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ *    val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](
+ *        StateSpec.function(trackingFunction).numPartitions(10))
  * }}}
  *
  * Example in Java:
  * {{{
- *    StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
- *       StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
+ *    StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
+ *       StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
  *                     .numPartitions(10);
  *
- *    JavaDStream[EmittedDataType] emittedRecordDStream =
- *       javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
+ *    JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedDataType> emittedRecordDStream =
+ *       javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
  * }}}
  */
 @Experimental
 object StateSpec {
   /**
    * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
-   * `trackStateByKey` operation on a
-   * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
-   * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+   * of the `trackStateByKey` operation on a
+   * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
+   *
    * @param trackingFunction The function applied on every data item to manage the associated state
    *                         and generate the emitted data
    * @tparam KeyType Class of the keys
@@ -141,9 +140,9 @@ object StateSpec {
   /**
    * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
-   * `trackStateByKey` operation on a
-   * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
-   * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+   * of the `trackStateByKey` operation on a
+   * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
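Putting the builder together: a sketch of a fully-specified StateSpec in Java, mirroring the options exercised by the tests later in this patch (trackStateFunc and initialRDD are as declared in those tests):

    StateSpec<String, Integer, Boolean, Double> spec =
        StateSpec.<String, Integer, Boolean, Double>function(trackStateFunc)
            .initialState(initialRDD)            // seed the state from an existing JavaPairRDD
            .numPartitions(10)                   // or: .partitioner(new HashPartitioner(10))
            .timeout(Durations.seconds(10));     // evict keys that receive no data for 10 seconds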
+   *
    * @param trackingFunction The function applied on every data item to manage the associated state
    *                         and generate the emitted data
    * @tparam ValueType Class of the values
@@ -160,6 +159,48 @@ object StateSpec {
     }
     new StateSpecImpl(wrappedFunction)
   }
+
+  /**
+   * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all
+   * the specifications of the `trackStateByKey` operation on a
+   * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
+   *
+   * @param javaTrackingFunction The function applied on every data item to manage the associated
+   *                             state and generate the emitted data
+   * @tparam KeyType Class of the keys
+   * @tparam ValueType Class of the values
+   * @tparam StateType Class of the states data
+   * @tparam EmittedType Class of the emitted data
+   */
+  def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction:
+      JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]):
+    StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+    val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
+      val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
+      Option(t.orNull)
+    }
+    StateSpec.function(trackingFunc)
+  }
+
+  /**
+   * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
+   * of the `trackStateByKey` operation on a
+   * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
+   *
+   * @param javaTrackingFunction The function applied on every data item to manage the associated
+   *                             state and generate the emitted data
+   * @tparam ValueType Class of the values
+   * @tparam StateType Class of the states data
+   * @tparam EmittedType Class of the emitted data
+   */
+  def function[KeyType, ValueType, StateType, EmittedType](
+      javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]):
+    StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+    val trackingFunc = (v: Option[ValueType], s: State[StateType]) => {
+      javaTrackingFunction.call(JavaUtils.optionToOptional(v), s)
+    }
+    StateSpec.function(trackingFunc)
+  }
 }
 
 
@@ -184,7 +225,6 @@ case class StateSpecImpl[K, V, S, T](
     this
   }
 
-
   override def numPartitions(numPartitions: Int): this.type = {
     this.partitioner(new HashPartitioner(numPartitions))
     this
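For the time-aware variant, the Function4 overload of StateSpec.function added above can be used instead. An illustrative sketch, assuming a hypothetical JavaPairDStream<String, Integer> named keyValueDStream:

    JavaTrackStateDStream<String, Integer, Integer, String> tracked =
        keyValueDStream.<Integer, String>trackStateByKey(StateSpec.function(
            new Function4<Time, String, Optional<Integer>, State<Integer>, Optional<String>>() {
              @Override
              public Optional<String> call(
                  Time time, String key, Optional<Integer> value, State<Integer> state) {
                int sum = value.or(0) + (state.exists() ? state.get() : 0);
                state.update(sum);
                // The current batch time is available, e.g. to timestamp emitted records.
                return Optional.of(key + " -> " + sum + " @ " + time.milliseconds());
              }
            }));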
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index e2aec6c2f6..70e32b383e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -28,8 +28,10 @@ import com.google.common.base.Optional
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.mapred.{JobConf, OutputFormat}
 import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
+
 import org.apache.spark.Partitioner
-import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaUtils}
 import org.apache.spark.api.java.JavaPairRDD._
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
@@ -426,6 +428,48 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
     )
   }
 
+  /**
+   * :: Experimental ::
+   * Return a new [[JavaDStream]] of data generated by combining the key-value data in `this`
+   * stream with a continuously updated per-key state. The user-provided state tracking function
+   * is applied on each keyed data item along with its corresponding state. The function can
+   * choose to update/remove the state and return transformed data, which forms the
+   * [[JavaTrackStateDStream]].
+   *
+   * The specifications of this transformation are made through the
+   * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
+   * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
+   * See the [[org.apache.spark.streaming.StateSpec StateSpec]] for more details.
+   *
+   * Example of using `trackStateByKey`:
+   * {{{
+   *    // A tracking function that maintains an integer state and returns a String
+   *    Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
+   *        new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
+   *
+   *          @Override
+   *          public Optional<String> call(Optional<Integer> one, State<Integer> state) {
+   *            // Check if state exists, accordingly update/remove state and return transformed data
+   *          }
+   *        };
+   *
+   *    JavaTrackStateDStream<Integer, Integer, Integer, String> trackStateDStream =
+   *        keyValueDStream.<Integer, String>trackStateByKey(
+   *            StateSpec.function(trackStateFunc).numPartitions(10));
+   * }}}
+   *
+   * @param spec Specification of this transformation
+   * @tparam StateType Class type of the state
+   * @tparam EmittedType Class type of the transformed data returned by the tracking function
+   */
+  @Experimental
+  def trackStateByKey[StateType, EmittedType](spec: StateSpec[K, V, StateType, EmittedType]):
+    JavaTrackStateDStream[K, V, StateType, EmittedType] = {
+    new JavaTrackStateDStream(dstream.trackStateByKey(spec)(
+      JavaSparkContext.fakeClassTag,
+      JavaSparkContext.fakeClassTag))
+  }
+
   private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]):
     (Seq[V], Option[S]) => Option[S] = {
     val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
new file mode 100644
index 0000000000..f459930d06
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.api.java
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.streaming.dstream.TrackStateDStream
+
+/**
+ * :: Experimental ::
+ * [[JavaDStream]] representing the stream of records emitted by the tracking function in the
+ * `trackStateByKey` operation on a [[JavaPairDStream]]. Additionally, it also gives access to the
+ * stream of state snapshots, that is, the state data of all keys after a batch has updated them.
+ *
+ * @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the state value
+ * @tparam StateType Class of the state
+ * @tparam EmittedType Class of the emitted records
+ */
+@Experimental
+class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType](
+    dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType])
+  extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) {
+
+  def stateSnapshots(): JavaPairDStream[KeyType, StateType] =
+    new JavaPairDStream(dstream.stateSnapshots())(
+      JavaSparkContext.fakeClassTag,
+      JavaSparkContext.fakeClassTag)
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
index 58d89c93bc..98e881e6ae 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
@@ -35,6 +35,7 @@ import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}
  * all keys after a batch has updated them.
  *
  * @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the state value
  * @tparam StateType Class of the state data
  * @tparam EmittedType Class of the emitted records
  */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
index ed7cea26d0..fc51496be4 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
@@ -70,12 +70,14 @@ private[streaming] class TrackStateRDDPartition(
  *                     in the `prevStateRDD` to create `this` RDD
  * @param trackingFunction The function that will be used to update state and return new data
  * @param batchTime        The time of the batch to which this RDD belongs to. Use to update
+ * @param timeoutThresholdTime The time to indicate which keys have timed out
  */
 private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
     private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
     private var partitionedDataRDD: RDD[(K, V)],
     trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
-    batchTime: Time, timeoutThresholdTime: Option[Long]
+    batchTime: Time,
+    timeoutThresholdTime: Option[Long]
   ) extends RDD[TrackStateRDDRecord[K, S, T]](
     partitionedDataRDD.sparkContext,
     List(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index ed622ef7bf..34287c3e00 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -267,7 +267,11 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
 
     // Read the data of the delta
     val deltaMapSize = inputStream.readInt()
-    deltaMap = new OpenHashMap[K, StateInfo[S]]()
+    deltaMap = if (deltaMapSize != 0) {
+        new OpenHashMap[K, StateInfo[S]](deltaMapSize)
+      } else {
+        new OpenHashMap[K, StateInfo[S]](initialCapacity)
+      }
     var deltaMapCount = 0
     while (deltaMapCount < deltaMapSize) {
       val key = inputStream.readObject().asInstanceOf[K]
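The StateMap change above sizes the delta map up front from the entry count read off the stream, avoiding repeated rehash-and-grow during deserialization. The same idea in plain Java, for illustration only (in is an assumed DataInputStream):

    // Illustrative: pre-size a map when the entry count is known before reading.
    int size = in.readInt();
    Map<String, Long> m = new HashMap<>(Math.max(size * 2, 16));  // headroom to limit rehashing
    for (int i = 0; i < size; i++) {
      m.put(in.readUTF(), in.readLong());
    }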
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
new file mode 100644
index 0000000000..eac4cdd14a
--- /dev/null
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+
+import scala.Tuple2;
+
+import com.google.common.base.Optional;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.util.ManualClock;
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.Function4;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
+
+public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implements Serializable {
+
+  /**
+   * This test is only for testing the APIs. It's not necessary to run it.
+   */
+  public void testAPI() {
+    JavaPairRDD<String, Boolean> initialRDD = null;
+    JavaPairDStream<String, Integer> wordsDstream = null;
+
+    final Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>
+        trackStateFunc =
+        new Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>() {
+
+          @Override
+          public Optional<Double> call(
+              Time time, String word, Optional<Integer> one, State<Boolean> state) {
+            // Use all State's methods here
+            state.exists();
+            state.get();
+            state.isTimingOut();
+            state.remove();
+            state.update(true);
+            return Optional.of(2.0);
+          }
+        };
+
+    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
+        wordsDstream.trackStateByKey(
+            StateSpec.function(trackStateFunc)
+                .initialState(initialRDD)
+                .numPartitions(10)
+                .partitioner(new HashPartitioner(10))
+                .timeout(Durations.seconds(10)));
+
+    JavaPairDStream<String, Boolean> stateSnapshots = stateDstream.stateSnapshots();
+
+    final Function2<Optional<Integer>, State<Boolean>, Double> trackStateFunc2 =
+        new Function2<Optional<Integer>, State<Boolean>, Double>() {
+
+          @Override
+          public Double call(Optional<Integer> one, State<Boolean> state) {
+            // Use all State's methods here
+            state.exists();
+            state.get();
+            state.isTimingOut();
+            state.remove();
+            state.update(true);
+            return 2.0;
+          }
+        };
+
+    JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
+        wordsDstream.trackStateByKey(
+            StateSpec.<String, Integer, Boolean, Double>function(trackStateFunc2)
+                .initialState(initialRDD)
+                .numPartitions(10)
+                .partitioner(new HashPartitioner(10))
+                .timeout(Durations.seconds(10)));
+
+    JavaPairDStream<String, Boolean> stateSnapshots2 = stateDstream2.stateSnapshots();
+  }
+
+  @Test
+  public void testBasicFunction() {
+    List<List<String>> inputData = Arrays.asList(
+        Collections.<String>emptyList(),
+        Arrays.asList("a"),
+        Arrays.asList("a", "b"),
+        Arrays.asList("a", "b", "c"),
+        Arrays.asList("a", "b"),
+        Arrays.asList("a"),
+        Collections.<String>emptyList()
+    );
+
+    List<Set<Integer>> outputData = Arrays.asList(
+        Collections.<Integer>emptySet(),
+        Sets.newHashSet(1),
+        Sets.newHashSet(2, 1),
+        Sets.newHashSet(3, 2, 1),
+        Sets.newHashSet(4, 3),
+        Sets.newHashSet(5),
+        Collections.<Integer>emptySet()
+    );
+
+    List<Set<Tuple2<String, Integer>>> stateData = Arrays.asList(
+        Collections.<Tuple2<String, Integer>>emptySet(),
+        Sets.newHashSet(new Tuple2<String, Integer>("a", 1)),
Sets.newHashSet(new Tuple2<String, Integer>("a", 2), new Tuple2<String, Integer>("b", 1)), + Sets.newHashSet( + new Tuple2<String, Integer>("a", 3), + new Tuple2<String, Integer>("b", 2), + new Tuple2<String, Integer>("c", 1)), + Sets.newHashSet( + new Tuple2<String, Integer>("a", 4), + new Tuple2<String, Integer>("b", 3), + new Tuple2<String, Integer>("c", 1)), + Sets.newHashSet( + new Tuple2<String, Integer>("a", 5), + new Tuple2<String, Integer>("b", 3), + new Tuple2<String, Integer>("c", 1)), + Sets.newHashSet( + new Tuple2<String, Integer>("a", 5), + new Tuple2<String, Integer>("b", 3), + new Tuple2<String, Integer>("c", 1)) + ); + + Function2<Optional<Integer>, State<Integer>, Integer> trackStateFunc = + new Function2<Optional<Integer>, State<Integer>, Integer>() { + + @Override + public Integer call(Optional<Integer> value, State<Integer> state) throws Exception { + int sum = value.or(0) + (state.exists() ? state.get() : 0); + state.update(sum); + return sum; + } + }; + testOperation( + inputData, + StateSpec.<String, Integer, Integer, Integer>function(trackStateFunc), + outputData, + stateData); + } + + private <K, S, T> void testOperation( + List<List<K>> input, + StateSpec<K, Integer, S, T> trackStateSpec, + List<Set<T>> expectedOutputs, + List<Set<Tuple2<K, S>>> expectedStateSnapshots) { + int numBatches = expectedOutputs.size(); + JavaDStream<K> inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2); + JavaTrackStateDStream<K, Integer, S, T> trackeStateStream = + JavaPairDStream.fromJavaDStream(inputStream.map(new Function<K, Tuple2<K, Integer>>() { + @Override + public Tuple2<K, Integer> call(K x) throws Exception { + return new Tuple2<K, Integer>(x, 1); + } + })).trackStateByKey(trackStateSpec); + + final List<Set<T>> collectedOutputs = + Collections.synchronizedList(Lists.<Set<T>>newArrayList()); + trackeStateStream.foreachRDD(new Function<JavaRDD<T>, Void>() { + @Override + public Void call(JavaRDD<T> rdd) throws Exception { + collectedOutputs.add(Sets.newHashSet(rdd.collect())); + return null; + } + }); + final List<Set<Tuple2<K, S>>> collectedStateSnapshots = + Collections.synchronizedList(Lists.<Set<Tuple2<K, S>>>newArrayList()); + trackeStateStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() { + @Override + public Void call(JavaPairRDD<K, S> rdd) throws Exception { + collectedStateSnapshots.add(Sets.newHashSet(rdd.collect())); + return null; + } + }); + BatchCounter batchCounter = new BatchCounter(ssc.ssc()); + ssc.start(); + ((ManualClock) ssc.ssc().scheduler().clock()) + .advance(ssc.ssc().progressListener().batchDuration() * numBatches + 1); + batchCounter.waitUntilBatchesCompleted(numBatches, 10000); + + Assert.assertEquals(expectedOutputs, collectedOutputs); + Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots); + } +} -- GitLab