diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 0802a2ae48e47e79473a9fd7dea84b516573b48b..4758e40e41be55d0fe5b249667ad14029e838bd6 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -1007,10 +1007,10 @@ test_that("spark.randomForest", {
   model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16,
                               numTrees = 20, seed = 123)
   predictions <- collect(predict(model, data))
-  expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258,
-                                         63.736, 64.296, 64.868, 64.300,
-                                         66.709, 67.697, 67.966, 67.252,
-                                         68.866, 69.593, 69.195, 69.658),
+  expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070,
+                                         63.53160, 64.05470, 65.12710, 64.30450,
+                                         66.70910, 67.86125, 68.08700, 67.21865,
+                                         68.89275, 69.53180, 69.39640, 69.68250),
                tolerance = 1e-4)
   stats <- summary(model)
   expect_equal(stats$numTrees, 20)
diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
index 297524c943e1fe49bde58bbb5c38f0f5b3624b2d..a7e0075debedbe488c754a65c9210eed32fb8ae7 100644
--- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala
@@ -56,11 +56,14 @@ private[spark] object SamplingUtils {
       val rand = new XORShiftRandom(seed)
       while (input.hasNext) {
         val item = input.next()
+        l += 1
+        // There are k elements in the reservoir, and the l-th item has just
+        // been consumed; it should enter the reservoir with probability k/l.
+        // The draw below is a random long chosen uniformly from [0, l).
         val replacementIndex = (rand.nextDouble() * l).toLong
         if (replacementIndex < k) {
           reservoir(replacementIndex.toInt) = item
         }
-        l += 1
       }
       (reservoir, l)
     }
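
For reference, here is a minimal standalone sketch of the corrected algorithm
(illustrative names only, scala.util.Random standing in for Spark's
XORShiftRandom, and assuming the input holds at least k items). Incrementing l
before the draw makes the uniform pick over [0, l) land in [0, k) with
probability exactly k/l, which is the invariant reservoir sampling requires:

    import scala.util.Random

    // Hypothetical sketch, not the Spark implementation.
    def reservoirSample(input: Iterator[Int], k: Int, seed: Long): (Array[Int], Long) = {
      val reservoir = new Array[Int](k)
      val rand = new Random(seed)
      var l = 0L
      // Fill the reservoir with the first k items.
      while (l < k && input.hasNext) {
        reservoir(l.toInt) = input.next()
        l += 1
      }
      while (input.hasNext) {
        val item = input.next()
        l += 1 // the l-th item has now been consumed
        // Uniform long over [0, l); it is < k with probability exactly k / l.
        val replacementIndex = (rand.nextDouble() * l).toLong
        if (replacementIndex < k) {
          reservoir(replacementIndex.toInt) = item
        }
      }
      (reservoir, l)
    }

Before the fix, the draw used the pre-increment count, so the first item after
the reservoir filled was drawn from [0, k) and displaced an existing entry with
probability k/k = 1. With k = 1 and a two-element input the sample was
therefore always the second element, which is exactly the skew the SPARK-18678
test below guards against.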
diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
index 667a4db6f7bb6aa94aa3c976a9fe3ad9b313bb26..55c5dd5e2460ddd1309401a53ca3c969d8a35cf7 100644
--- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala
@@ -44,6 +44,19 @@ class SamplingUtilsSuite extends SparkFunSuite {
     assert(sample3.length === 10)
   }
 
+  test("SPARK-18678 reservoirSampleAndCount with tiny input") {
+    val input = Seq(0, 1)
+    val counts = new Array[Int](input.size)
+    for (i <- 0 until 500) {
+      val (samples, inputSize) = SamplingUtils.reservoirSampleAndCount(input.iterator, 1)
+      assert(inputSize === 2)
+      assert(samples.length === 1)
+      counts(samples.head) += 1
+    }
+    // If the sampler is unbiased, this holds with probability ~ 0.99999707.
+    assert(math.abs(counts(0) - counts(1)) <= 100)
+  }
+
   test("computeFraction") {
     // test that the computed fraction guarantees enough data points
     // in the sample with a failure rate <= 0.0001
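
The probability quoted in the new test's comment can be checked directly.
Assuming each call draws a fresh default seed, the 500 trials are independent,
and under an unbiased sampler counts(0) follows Binomial(500, 1/2); the
assertion holds iff counts(0) falls in [200, 300]. A self-contained sketch
(plain Scala, names are illustrative) that sums the exact binomial pmf over
that window:

    // Standalone check, not Spark code: exact two-sided binomial probability.
    object BinomialCheck {
      def main(args: Array[String]): Unit = {
        val n = 500
        // Precompute log-factorials so the pmf can be evaluated without overflow.
        val logFact = new Array[Double](n + 1)
        for (i <- 1 to n) logFact(i) = logFact(i - 1) + math.log(i)
        def pmf(i: Int): Double =
          math.exp(logFact(n) - logFact(i) - logFact(n - i) - n * math.log(2))
        // |counts(0) - counts(1)| <= 100 is equivalent to counts(0) in [200, 300].
        val p = (200 to 300).map(pmf).sum
        println(f"P(assertion holds under fair sampling) = $p%.8f") // ~ 0.99999707
      }
    }

Equivalently, a correct sampler should fail this test only about three times
per million runs.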