diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index f14ff6705fd97c0502b7614e5a01c0e81acff693..cb782d27fe22cb61714c5f45a18057e2bc0cdd40 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -281,14 +281,20 @@ class DirectKafkaStreamSuite
       sendDataAndWaitForReceive(i)
     }
 
+    ssc.stop()
+
     // Verify that offset ranges were generated
-    val offsetRangesBeforeStop = getOffsetRanges(kafkaStream)
-    assert(offsetRangesBeforeStop.size >= 1, "No offset ranges generated")
+    // Since "offsetRangesAfterStop" will be used to compare with "recoveredOffsetRanges", we should
+    // collect offset ranges after stopping. Otherwise, because new RDDs keep being generated before
+    // stopping, we may not be able to get the latest RDDs, then "recoveredOffsetRanges" will
+    // contain something not in "offsetRangesAfterStop".
+    val offsetRangesAfterStop = getOffsetRanges(kafkaStream)
+    assert(offsetRangesAfterStop.size >= 1, "No offset ranges generated")
     assert(
-      offsetRangesBeforeStop.head._2.forall { _.fromOffset === 0 },
+      offsetRangesAfterStop.head._2.forall { _.fromOffset === 0 },
       "starting offset not zero"
     )
-    ssc.stop()
+
     logInfo("====== RESTARTING ========")
 
     // Recover context from checkpoints
@@ -298,12 +304,14 @@ class DirectKafkaStreamSuite
     // Verify offset ranges have been recovered
     val recoveredOffsetRanges = getOffsetRanges(recoveredStream)
     assert(recoveredOffsetRanges.size > 0, "No offset ranges recovered")
-    val earlierOffsetRangesAsSets = offsetRangesBeforeStop.map { x => (x._1, x._2.toSet) }
+    val earlierOffsetRangesAsSets = offsetRangesAfterStop.map { x => (x._1, x._2.toSet) }
     assert(
       recoveredOffsetRanges.forall { or =>
         earlierOffsetRangesAsSets.contains((or._1, or._2.toSet))
       },
-      "Recovered ranges are not the same as the ones generated"
+      "Recovered ranges are not the same as the ones generated\n" +
+        s"recoveredOffsetRanges: $recoveredOffsetRanges\n" +
+        s"earlierOffsetRangesAsSets: $earlierOffsetRangesAsSets"
     )
     // Restart context, give more data and verify the total at the end
     // If the total is write that means each records has been received only once