Commit e4e46b20 authored by Tathagata Das

[SPARK-11681][STREAMING] Correctly update state timestamp even when state is not updated

Bug: The timestamp is not updated if there is data but the corresponding state is not updated. This is wrong, as timeout is defined as "no data for a while", not "no state update for a while".

Fix: Update the timestamp whenever a timeout is specified; otherwise there is no need.
Also refactored the code for better testability and added unit tests.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #9648 from tdas/SPARK-11681.
parent 7786f9cc
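For context, here is a minimal usage sketch of the semantics this commit fixes, written against the `trackStateByKey`/`StateSpec` API as it stood at the time (the stream `keyedDStream`, the running-maximum logic, and the 30-second timeout are illustrative, not part of this patch). The tracking function below calls `state.update()` only when a new maximum arrives; before this fix, a key receiving a steady stream of non-maximal values could therefore be timed out even though it was receiving data in every batch.

    import org.apache.spark.streaming.{Seconds, State, StateSpec, Time}

    // Track a running maximum per key, emitting the current maximum for every record.
    val trackingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
      if (state.isTimingOut()) {
        None  // this key received no data at all for the whole timeout interval
      } else {
        val v = value.getOrElse(Int.MinValue)
        if (v > state.getOption().getOrElse(Int.MinValue)) {
          state.update(v)  // the state is updated only when a new maximum arrives
        }
        state.getOption().map(max => (key, max))
      }
    }

    // With this fix, any key that keeps receiving data stays alive, whether or not
    // the tracking function updated its state in a given batch.
    val spec = StateSpec.function(trackingFunc).timeout(Seconds(30))
    val maxStream = keyedDStream.trackStateByKey(spec)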
@@ -32,8 +32,51 @@ import org.apache.spark._
  * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a
  * sequence of records returned by the tracking function of `trackStateByKey`.
  */
-private[streaming] case class TrackStateRDDRecord[K, S, T](
-    var stateMap: StateMap[K, S], var emittedRecords: Seq[T])
+private[streaming] case class TrackStateRDDRecord[K, S, E](
+    var stateMap: StateMap[K, S], var emittedRecords: Seq[E])
+
+private[streaming] object TrackStateRDDRecord {
+  def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    prevRecord: Option[TrackStateRDDRecord[K, S, E]],
+    dataIterator: Iterator[(K, V)],
+    updateFunction: (Time, K, Option[V], State[S]) => Option[E],
+    batchTime: Time,
+    timeoutThresholdTime: Option[Long],
+    removeTimedoutData: Boolean
+  ): TrackStateRDDRecord[K, S, E] = {
+    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
+    val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }
+
+    val emittedRecords = new ArrayBuffer[E]
+    val wrappedState = new StateImpl[S]()
+
+    // Call the tracking function on each record in the data iterator, and accordingly
+    // update the states touched, and collect the data returned by the tracking function
+    dataIterator.foreach { case (key, value) =>
+      wrappedState.wrap(newStateMap.get(key))
+      val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState)
+      if (wrappedState.isRemoved) {
+        newStateMap.remove(key)
+      } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
+        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
+      }
+      emittedRecords ++= emittedRecord
+    }
+
+    // Get the timed out state records, call the tracking function on each and collect the
+    // data returned
+    if (removeTimedoutData && timeoutThresholdTime.isDefined) {
+      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
+        wrappedState.wrapTiminoutState(state)
+        val emittedRecord = updateFunction(batchTime, key, None, wrappedState)
+        emittedRecords ++= emittedRecord
+        newStateMap.remove(key)
+      }
+    }
+
+    TrackStateRDDRecord(newStateMap, emittedRecords)
+  }
+}

 /**
  * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state
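The behavioral core of the patch is the `put` condition in `updateRecordWithData` above: previously a key's timestamp was refreshed only when `wrappedState.isUpdated`, so a key could time out despite receiving data in every batch; now the timestamp is also refreshed whenever a timeout threshold is configured. A self-contained sketch of the rule, using a plain mutable Map in place of the [[StateMap]] (all names here are illustrative, not Spark API):

    import scala.collection.mutable

    // (stateValue, lastUpdateTimeMs) per key, standing in for the StateMap.
    val stateMap = mutable.Map("k" -> (1, 1000L))

    def recordData(key: String, stateUpdated: Boolean, timeoutDefined: Boolean,
        batchTimeMs: Long): Unit = {
      val value = stateMap(key)._1
      // Old rule: `if (stateUpdated)` -- data that left the state untouched also left the
      // timestamp stale, so the key could still be timed out.
      // New rule (this commit): data refreshes the timestamp whenever a timeout is in
      // effect, matching the definition of timeout as "no data for a while".
      if (stateUpdated || timeoutDefined) {
        stateMap(key) = (value, batchTimeMs)
      }
    }

    recordData("k", stateUpdated = false, timeoutDefined = true, batchTimeMs = 2000L)
    println(stateMap("k"))  // (1,2000): timestamp refreshed, so the key is not timed out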
@@ -72,16 +115,16 @@ private[streaming] class TrackStateRDDPartition(
  * @param batchTime            The time of the batch to which this RDD belongs to. Use to update
  * @param timeoutThresholdTime The time to indicate which keys are timeout
  */
-private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
+private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]],
     private var partitionedDataRDD: RDD[(K, V)],
-    trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
+    trackingFunction: (Time, K, Option[V], State[S]) => Option[E],
     batchTime: Time,
     timeoutThresholdTime: Option[Long]
-  ) extends RDD[TrackStateRDDRecord[K, S, T]](
+  ) extends RDD[TrackStateRDDRecord[K, S, E]](
     partitionedDataRDD.sparkContext,
     List(
-      new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD),
+      new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD),
       new OneToOneDependency(partitionedDataRDD))
   ) {
@@ -98,7 +141,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T:
   }

   override def compute(
-      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = {
+      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = {

     val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
     val prevStateRDDIterator = prevStateRDD.iterator(
@@ -106,42 +149,16 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T:
     val dataIterator = partitionedDataRDD.iterator(
       stateRDDPartition.partitionedDataRDDPartition, context)

-    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
-    val newStateMap = if (prevStateRDDIterator.hasNext) {
-      prevStateRDDIterator.next().stateMap.copy()
-    } else {
-      new EmptyStateMap[K, S]()
-    }
-
-    val emittedRecords = new ArrayBuffer[T]
-    val wrappedState = new StateImpl[S]()
-
-    // Call the tracking function on each record in the data RDD partition, and accordingly
-    // update the states touched, and the data returned by the tracking function.
-    dataIterator.foreach { case (key, value) =>
-      wrappedState.wrap(newStateMap.get(key))
-      val emittedRecord = trackingFunction(batchTime, key, Some(value), wrappedState)
-      if (wrappedState.isRemoved) {
-        newStateMap.remove(key)
-      } else if (wrappedState.isUpdated) {
-        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
-      }
-      emittedRecords ++= emittedRecord
-    }
-
-    // If the RDD is expected to be doing a full scan of all the data in the StateMap,
-    // then use this opportunity to filter out those keys that have timed out.
-    // For each of them call the tracking function.
-    if (doFullScan && timeoutThresholdTime.isDefined) {
-      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
-        wrappedState.wrapTiminoutState(state)
-        val emittedRecord = trackingFunction(batchTime, key, None, wrappedState)
-        emittedRecords ++= emittedRecord
-        newStateMap.remove(key)
-      }
-    }
-
-    Iterator(TrackStateRDDRecord(newStateMap, emittedRecords))
+    val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
+    val newRecord = TrackStateRDDRecord.updateRecordWithData(
+      prevRecord,
+      dataIterator,
+      trackingFunction,
+      batchTime,
+      timeoutThresholdTime,
+      removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
+    )
+    Iterator(newRecord)
   }

   override protected def getPartitions: Array[Partition] = {
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
 import org.scalatest.BeforeAndAfterAll

 import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
+import org.apache.spark.streaming.{Time, State}
 import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite}
@@ -52,6 +53,131 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
     assert(rdd.partitioner === Some(partitioner))
   }

+  test("updating state and generating emitted data in TrackStateRecord") {
+
+    val initialTime = 1000L
+    val updatedTime = 2000L
+    val thresholdTime = 1500L
+    @volatile var functionCalled = false
+
+    /**
+     * Assert that applying given data on a prior record generates correct updated record, with
+     * correct state map and emitted data
+     */
+    def assertRecordUpdate(
+        initStates: Iterable[Int],
+        data: Iterable[String],
+        expectedStates: Iterable[(Int, Long)],
+        timeoutThreshold: Option[Long] = None,
+        removeTimedoutData: Boolean = false,
+        expectedOutput: Iterable[Int] = None,
+        expectedTimingOutStates: Iterable[Int] = None,
+        expectedRemovedStates: Iterable[Int] = None
+      ): Unit = {
+      val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
+      initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
+      functionCalled = false
+
+      val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
+      val dataIterator = data.map { v => ("key", v) }.iterator
+      val removedStates = new ArrayBuffer[Int]
+      val timingOutStates = new ArrayBuffer[Int]
+
+      /**
+       * Tracking function that updates/removes state based on instructions in the data, and
+       * return state (when instructed or when state is timing out).
+       */
+      def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
+        functionCalled = true
+
+        assert(t.milliseconds === updatedTime, "tracking func called with wrong time")
+
+        data match {
+          case Some("noop") =>
+            None
+          case Some("get-state") =>
+            Some(state.getOption().getOrElse(-1))
+          case Some("update-state") =>
+            if (state.exists) state.update(state.get + 1) else state.update(0)
+            None
+          case Some("remove-state") =>
+            removedStates += state.get()
+            state.remove()
+            None
+          case None =>
+            assert(state.isTimingOut() === true, "State is not timing out when data = None")
+            timingOutStates += state.get()
+            None
+          case _ =>
+            fail("Unexpected test data")
+        }
+      }
+
+      val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int](
+        Some(record), dataIterator, testFunc,
+        Time(updatedTime), timeoutThreshold, removeTimedoutData)
+
+      val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
+      assert(updatedStateData.toSet === expectedStates.toSet,
+        "states do not match after updating the TrackStateRecord")
+
+      assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet,
+        "emitted data do not match after updating the TrackStateRecord")
+
+      assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
+        "match those that were expected to do so while updating the TrackStateRecord")
+
+      assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
+        "match those that were expected to do so while updating the TrackStateRecord")
+    }
+
+    // No data, no state should be changed, function should not be called
+    assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil)
+    assert(functionCalled === false)
+
+    assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime)))
+    assert(functionCalled === false)
+
+    // Data present, function should be called irrespective of whether state exists
+    assertRecordUpdate(initStates = Seq(0), data = Seq("noop"),
+      expectedStates = Seq((0, initialTime)))
+    assert(functionCalled === true)
+
+    assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None)
+    assert(functionCalled === true)
+
+    // Function called with right state data
+    assertRecordUpdate(initStates = None, data = Seq("get-state"),
+      expectedStates = None, expectedOutput = Seq(-1))
+
+    assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"),
+      expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123))
+
+    // Update state and timestamp, when timeout not present
+    assertRecordUpdate(initStates = Nil, data = Seq("update-state"),
+      expectedStates = Seq((0, updatedTime)))
+
+    assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"),
+      expectedStates = Seq((1, updatedTime)))
+
+    // Remove state
+    assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"),
+      expectedStates = Nil, expectedRemovedStates = Seq(345))
+
+    // State strictly older than timeout threshold should be timed out
+    assertRecordUpdate(initStates = Seq(123), data = Nil,
+      timeoutThreshold = Some(initialTime), removeTimedoutData = true,
+      expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil)
+
+    assertRecordUpdate(initStates = Seq(123), data = Nil,
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Nil, expectedTimingOutStates = Seq(123))
+
+    // State should not be timed out after it has received data
+    assertRecordUpdate(initStates = Seq(123), data = Seq("noop"),
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil)
+
+    assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"),
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123))
+  }
+
   test("states generated by TrackStateRDD") {
     val initStates = Seq(("k1", 0), ("k2", 0))
     val initTime = 123
@@ -148,9 +274,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
     val rdd7 = testStateUpdates( // should remove k2's state
       rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))

-    val rdd8 = testStateUpdates(
-      rdd7, Seq(("k3", 2)), Set() //
-    )
+    val rdd8 = testStateUpdates( // should remove k3's state
+      rdd7, Seq(("k3", 2)), Set())
   }

   /** Assert whether the `trackStateByKey` operation generates expected results */
@@ -176,7 +301,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
     // Persist to make sure that it gets computed only once and we can track precisely how many
     // state keys the computing touched
-    newStateRDD.persist()
+    newStateRDD.persist().count()
     assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
     newStateRDD
   }
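A note on the one-line change above: `persist()` merely marks the RDD for caching and computes nothing by itself, so on its own it did not deliver the "computed only once" guarantee described in the comment. Chaining `.count()` runs an action that materializes and caches the RDD immediately, so the assertions that follow read from the cache instead of recomputing the state. The same pattern in isolation (`expensiveStateUpdate` is a hypothetical function):

    val stateRDD = sc.parallelize(1 to 10).map(expensiveStateUpdate)
    stateRDD.persist()  // lazy: nothing is computed or cached yet
    stateRDD.count()    // action: computes the RDD once and populates the cache
    stateRDD.collect()  // served from the cache; expensiveStateUpdate does not run again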
@@ -188,7 +313,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
       expectedEmittedRecords: Set[T]): Unit = {
     val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
     val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
-    assert(states === expectedStates, "states after track state operation were not as expected")
+    assert(states === expectedStates,
+      "states after track state operation were not as expected")
     assert(emittedRecords === expectedEmittedRecords,
       "emitted records after track state operation were not as expected")
   }