diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
index fc51496be47bf188267b903710b490a659fa08b2..7050378d0feb09510ed7eac60d49a1c1c1c63085 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
@@ -32,8 +32,51 @@ import org.apache.spark._
  * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a
  * sequence of records returned by the tracking function of `trackStateByKey`.
  */
-private[streaming] case class TrackStateRDDRecord[K, S, T](
-    var stateMap: StateMap[K, S], var emittedRecords: Seq[T])
+private[streaming] case class TrackStateRDDRecord[K, S, E](
+    var stateMap: StateMap[K, S], var emittedRecords: Seq[E])
+
+private[streaming] object TrackStateRDDRecord {
+  def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    prevRecord: Option[TrackStateRDDRecord[K, S, E]],
+    dataIterator: Iterator[(K, V)],
+    updateFunction: (Time, K, Option[V], State[S]) => Option[E],
+    batchTime: Time,
+    timeoutThresholdTime: Option[Long],
+    removeTimedoutData: Boolean
+  ): TrackStateRDDRecord[K, S, E] = {
+    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
+    val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }
+
+    val emittedRecords = new ArrayBuffer[E]
+    val wrappedState = new StateImpl[S]()
+
+    // Call the tracking function on each record in the data iterator, and accordingly
+    // update the states touched, and collect the data returned by the tracking function
+    dataIterator.foreach { case (key, value) =>
+      wrappedState.wrap(newStateMap.get(key))
+      val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState)
+      if (wrappedState.isRemoved) {
+        newStateMap.remove(key)
+      } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
+        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
+      }
+      emittedRecords ++= emittedRecord
+    }
+
+    // Get the timed out state records, call the tracking function on each and collect the
+    // data returned
+    if (removeTimedoutData && timeoutThresholdTime.isDefined) {
+      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
+        wrappedState.wrapTiminoutState(state)
+        val emittedRecord = updateFunction(batchTime, key, None, wrappedState)
+        emittedRecords ++= emittedRecord
+        newStateMap.remove(key)
+      }
+    }
+
+    TrackStateRDDRecord(newStateMap, emittedRecords)
+  }
+}
 
 /**
  * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state
@@ -72,16 +115,16 @@ private[streaming] class TrackStateRDDPartition(
  * @param batchTime        The time of the batch to which this RDD belongs to. Use to update
  * @param timeoutThresholdTime The time to indicate which keys are timeout
  */
-private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
+private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]],
     private var partitionedDataRDD: RDD[(K, V)],
-    trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
+    trackingFunction: (Time, K, Option[V], State[S]) => Option[E],
     batchTime: Time,
     timeoutThresholdTime: Option[Long]
-  ) extends RDD[TrackStateRDDRecord[K, S, T]](
+  ) extends RDD[TrackStateRDDRecord[K, S, E]](
     partitionedDataRDD.sparkContext,
     List(
-      new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD),
+      new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD),
       new OneToOneDependency(partitionedDataRDD))
   ) {
 
@@ -98,7 +141,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T:
   }
 
   override def compute(
-      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = {
+      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = {
 
     val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
     val prevStateRDDIterator = prevStateRDD.iterator(
@@ -106,42 +149,16 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T:
     val dataIterator = partitionedDataRDD.iterator(
       stateRDDPartition.partitionedDataRDDPartition, context)
 
-    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
-    val newStateMap = if (prevStateRDDIterator.hasNext) {
-      prevStateRDDIterator.next().stateMap.copy()
-    } else {
-      new EmptyStateMap[K, S]()
-    }
-
-    val emittedRecords = new ArrayBuffer[T]
-    val wrappedState = new StateImpl[S]()
-
-    // Call the tracking function on each record in the data RDD partition, and accordingly
-    // update the states touched, and the data returned by the tracking function.
-    dataIterator.foreach { case (key, value) =>
-      wrappedState.wrap(newStateMap.get(key))
-      val emittedRecord = trackingFunction(batchTime, key, Some(value), wrappedState)
-      if (wrappedState.isRemoved) {
-        newStateMap.remove(key)
-      } else if (wrappedState.isUpdated) {
-        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
-      }
-      emittedRecords ++= emittedRecord
-    }
-
-    // If the RDD is expected to be doing a full scan of all the data in the StateMap,
-    // then use this opportunity to filter out those keys that have timed out.
-    // For each of them call the tracking function.
-    if (doFullScan && timeoutThresholdTime.isDefined) {
-      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
-        wrappedState.wrapTiminoutState(state)
-        val emittedRecord = trackingFunction(batchTime, key, None, wrappedState)
-        emittedRecords ++= emittedRecord
-        newStateMap.remove(key)
-      }
-    }
-
-    Iterator(TrackStateRDDRecord(newStateMap, emittedRecords))
+    val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
+    val newRecord = TrackStateRDDRecord.updateRecordWithData(
+      prevRecord,
+      dataIterator,
+      trackingFunction,
+      batchTime,
+      timeoutThresholdTime,
+      removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
+    )
+    Iterator(newRecord)
   }
 
   override protected def getPartitions: Array[Partition] = {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
index f396b76e8d2516165067eb55abe8b4efc9958beb..19ef5a14f8ab4674d59459e870999580403cdc7c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
 import org.apache.spark.streaming.{Time, State}
 import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite}
 
@@ -52,6 +53,131 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
     assert(rdd.partitioner === Some(partitioner))
   }
 
+  test("updating state and generating emitted data in TrackStateRecord") {
+
+    val initialTime = 1000L
+    val updatedTime = 2000L
+    val thresholdTime = 1500L
+    @volatile var functionCalled = false
+
+    /**
+     * Assert that applying given data on a prior record generates correct updated record, with
+     * correct state map and emitted data
+     */
+    def assertRecordUpdate(
+        initStates: Iterable[Int],
+        data: Iterable[String],
+        expectedStates: Iterable[(Int, Long)],
+        timeoutThreshold: Option[Long] = None,
+        removeTimedoutData: Boolean = false,
+        expectedOutput: Iterable[Int] = None,
+        expectedTimingOutStates: Iterable[Int] = None,
+        expectedRemovedStates: Iterable[Int] = None
+      ): Unit = {
+      val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
+      initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
+      functionCalled = false
+      val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
+      val dataIterator = data.map { v => ("key", v) }.iterator
+      val removedStates = new ArrayBuffer[Int]
+      val timingOutStates = new ArrayBuffer[Int]
+      /**
+       * Tracking function that updates/removes state based on instructions in the data, and
+       * return state (when instructed or when state is timing out).
+       */
+      def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
+        functionCalled = true
+
+        assert(t.milliseconds === updatedTime, "tracking func called with wrong time")
+
+        data match {
+          case Some("noop") =>
+            None
+          case Some("get-state") =>
+            Some(state.getOption().getOrElse(-1))
+          case Some("update-state") =>
+            if (state.exists) state.update(state.get + 1) else state.update(0)
+            None
+          case Some("remove-state") =>
+            removedStates += state.get()
+            state.remove()
+            None
+          case None =>
+            assert(state.isTimingOut() === true, "State is not timing out when data = None")
+            timingOutStates += state.get()
+            None
+          case _ =>
+            fail("Unexpected test data")
+        }
+      }
+
+      val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int](
+        Some(record), dataIterator, testFunc,
+        Time(updatedTime), timeoutThreshold, removeTimedoutData)
+
+      val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
+      assert(updatedStateData.toSet === expectedStates.toSet,
+        "states do not match after updating the TrackStateRecord")
+
+      assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet,
+        "emitted data do not match after updating the TrackStateRecord")
+
+      assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
+        "match those that were expected to do so while updating the TrackStateRecord")
+
+      assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
+        "match those that were expected to do so while updating the TrackStateRecord")
+
+    }
+
+    // No data, no state should be changed, function should not be called,
+    assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil)
+    assert(functionCalled === false)
+    assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime)))
+    assert(functionCalled === false)
+
+    // Data present, function should be called irrespective of whether state exists
+    assertRecordUpdate(initStates = Seq(0), data = Seq("noop"),
+      expectedStates = Seq((0, initialTime)))
+    assert(functionCalled === true)
+    assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None)
+    assert(functionCalled === true)
+
+    // Function called with right state data
+    assertRecordUpdate(initStates = None, data = Seq("get-state"),
+      expectedStates = None, expectedOutput = Seq(-1))
+    assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"),
+      expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123))
+
+    // Update state and timestamp, when timeout not present
+    assertRecordUpdate(initStates = Nil, data = Seq("update-state"),
+      expectedStates = Seq((0, updatedTime)))
+    assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"),
+      expectedStates = Seq((1, updatedTime)))
+
+    // Remove state
+    assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"),
+      expectedStates = Nil, expectedRemovedStates = Seq(345))
+
+    // State strictly older than timeout threshold should be timed out
+    assertRecordUpdate(initStates = Seq(123), data = Nil,
+      timeoutThreshold = Some(initialTime), removeTimedoutData = true,
+      expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil)
+
+    assertRecordUpdate(initStates = Seq(123), data = Nil,
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Nil, expectedTimingOutStates = Seq(123))
+
+    // State should not be timed out after it has received data
+    assertRecordUpdate(initStates = Seq(123), data = Seq("noop"),
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil)
+    assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"),
+      timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+      expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123))
+
+  }
+
   test("states generated by TrackStateRDD") {
     val initStates = Seq(("k1", 0), ("k2", 0))
     val initTime = 123
@@ -148,9 +274,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
     val rdd7 = testStateUpdates(                      // should remove k2's state
       rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))
 
-    val rdd8 = testStateUpdates(
-      rdd7, Seq(("k3", 2)), Set()                     //
-    )
+    val rdd8 = testStateUpdates(                      // should remove k3's state
+      rdd7, Seq(("k3", 2)), Set())
   }
 
   /** Assert whether the `trackStateByKey` operation generates expected results */
@@ -176,7 +301,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
 
     // Persist to make sure that it gets computed only once and we can track precisely how many
     // state keys the computing touched
-    newStateRDD.persist()
+    newStateRDD.persist().count()
     assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
     newStateRDD
   }
@@ -188,7 +313,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
       expectedEmittedRecords: Set[T]): Unit = {
     val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
     val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
-    assert(states === expectedStates, "states after track state operation were not as expected")
+    assert(states === expectedStates,
+      "states after track state operation were not as expected")
     assert(emittedRecords === expectedEmittedRecords,
       "emitted records after track state operation were not as expected")
   }