diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala
index 3718542810d0be86ee3eda51ba278205798568c7..1fc0ad7a4d6d3d1e77bfdc307bc238587688fb49 100644
--- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala
+++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala
@@ -42,7 +42,24 @@ private[spark] class Benchmark(
     outputPerIteration: Boolean = false) {
   val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case]
 
+  /**
+   * Adds a case to run when run() is called. The given function will be run for several
+   * iterations to collect timing statistics.
+   */
   def addCase(name: String)(f: Int => Unit): Unit = {
+    addTimerCase(name) { timer =>
+      timer.startTiming()
+      f(timer.iteration)
+      timer.stopTiming()
+    }
+  }
+
+  /**
+   * Adds a case with manual timing control. When the function is run, timing does not start
+   * until timer.startTiming() is called within the given function. The corresponding
+   * timer.stopTiming() method must be called before the function returns.
+   */
+  def addTimerCase(name: String)(f: Benchmark.Timer => Unit): Unit = {
     benchmarks += Benchmark.Case(name, f)
   }
 
@@ -84,7 +101,34 @@ private[spark] class Benchmark(
 }
 
 private[spark] object Benchmark {
-  case class Case(name: String, fn: Int => Unit)
+
+  /**
+   * Object available to benchmark code to control timing e.g. to exclude set-up time.
+   *
+   * @param iteration specifies this is the nth iteration of running the benchmark case
+   */
+  class Timer(val iteration: Int) {
+    private var accumulatedTime: Long = 0L
+    private var timeStart: Long = 0L
+
+    def startTiming(): Unit = {
+      assert(timeStart == 0L, "Already started timing.")
+      timeStart = System.nanoTime
+    }
+
+    def stopTiming(): Unit = {
+      assert(timeStart != 0L, "Have not started timing.")
+      accumulatedTime += System.nanoTime - timeStart
+      timeStart = 0L
+    }
+
+    def totalTime(): Long = {
+      assert(timeStart == 0L, "Have not stopped timing.")
+      accumulatedTime
+    }
+  }
+
+  case class Case(name: String, fn: Timer => Unit)
   case class Result(avgMs: Double, bestRate: Double, bestMs: Double)
 
   /**
@@ -123,15 +167,12 @@ private[spark] object Benchmark {
    * Runs a single function `f` for iters, returning the average time the function took and
    * the rate of the function.
    */
-  def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Int => Unit): Result = {
+  def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Timer => Unit): Result = {
     val runTimes = ArrayBuffer[Long]()
     for (i <- 0 until iters + 1) {
-      val start = System.nanoTime()
-
-      f(i)
-
-      val end = System.nanoTime()
-      val runTime = end - start
+      val timer = new Benchmark.Timer(i)
+      f(timer)
+      val runTime = timer.totalTime()
       if (i > 0) {
         runTimes += runTime
       }