diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala
index 3718542810d0be86ee3eda51ba278205798568c7..1fc0ad7a4d6d3d1e77bfdc307bc238587688fb49 100644
--- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala
+++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala
@@ -42,7 +42,24 @@ private[spark] class Benchmark(
     outputPerIteration: Boolean = false) {
   val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case]
 
+  /**
+   * Adds a case to run when run() is called. The given function will be run for several
+   * iterations to collect timing statistics.
+   */
   def addCase(name: String)(f: Int => Unit): Unit = {
+    addTimerCase(name) { timer =>
+      timer.startTiming()
+      f(timer.iteration)
+      timer.stopTiming()
+    }
+  }
+
+  /**
+   * Adds a case with manual timing control. When the function is run, timing does not start
+   * until timer.startTiming() is called within the given function. The corresponding
+   * timer.stopTiming() method must be called before the function returns.
+   */
+  def addTimerCase(name: String)(f: Benchmark.Timer => Unit): Unit = {
     benchmarks += Benchmark.Case(name, f)
   }
 
@@ -84,7 +101,34 @@ private[spark] class Benchmark(
 }
 
 private[spark] object Benchmark {
-  case class Case(name: String, fn: Int => Unit)
+
+  /**
+   * Object available to benchmark code to control timing e.g. to exclude set-up time.
+   *
+   * @param iteration specifies this is the nth iteration of running the benchmark case
+   */
+  class Timer(val iteration: Int) {
+    private var accumulatedTime: Long = 0L
+    private var timeStart: Long = 0L
+
+    def startTiming(): Unit = {
+      assert(timeStart == 0L, "Already started timing.")
+      timeStart = System.nanoTime
+    }
+
+    def stopTiming(): Unit = {
+      assert(timeStart != 0L, "Have not started timing.")
+      accumulatedTime += System.nanoTime - timeStart
+      timeStart = 0L
+    }
+
+    def totalTime(): Long = {
+      assert(timeStart == 0L, "Have not stopped timing.")
+      accumulatedTime
+    }
+  }
+
+  case class Case(name: String, fn: Timer => Unit)
   case class Result(avgMs: Double, bestRate: Double, bestMs: Double)
 
   /**
@@ -123,15 +167,12 @@ private[spark] object Benchmark {
    * Runs a single function `f` for iters, returning the average time the function took and
    * the rate of the function.
    */
-  def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Int => Unit): Result = {
+  def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Timer => Unit): Result = {
     val runTimes = ArrayBuffer[Long]()
     for (i <- 0 until iters + 1) {
-      val start = System.nanoTime()
-
-      f(i)
-
-      val end = System.nanoTime()
-      val runTime = end - start
+      val timer = new Benchmark.Timer(i)
+      f(timer)
+      val runTime = timer.totalTime()
       if (i > 0) {
         runTimes += runTime
       }
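
Usage sketch (not part of the patch): assuming a Benchmark instance named benchmark has already been constructed, the new addTimerCase method lets a case build its input before starting the clock, so only the measured loop contributes to the reported time. The case name and data size below are illustrative.

benchmark.addTimerCase("sum, setup excluded") { timer =>
  // Set-up work done before startTiming() (building the input array) is not timed.
  val data = Array.fill(1 << 20)(1L)
  timer.startTiming()
  var i = 0
  var sum = 0L
  while (i < data.length) {
    sum += data(i)
    i += 1
  }
  timer.stopTiming()
}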