Skip to content
Snippets Groups Projects
Commit 8237df80 authored by witgo's avatar witgo Committed by Patrick Wendell
Browse files

Avoid Option while generating call site

This is an update on https://github.com/apache/spark/pull/180, which changes the solution from blacklisting "Option.scala" to avoiding the Option code path while generating the call site.

Also includes a unit test to prevent this issue in the future, and some minor refactoring.

Thanks @witgo for reporting this issue and working on the initial solution!

Author: witgo <witgo@qq.com>
Author: Aaron Davidson <aaron@databricks.com>

Closes #222 from aarondav/180 and squashes the following commits:

f74aad1 [Aaron Davidson] Avoid Option while generating call site & add unit tests
d2b4980 [witgo] Modify the position of the filter
1bc22d7 [witgo] Fix Stage.name return "apply at Option.scala:120"
parent f8111eae
No related branches found
No related tags found
No related merge requests found
......@@ -877,7 +877,8 @@ class SparkContext(
* has overridden the call site, this will return the user's version.
*/
private[spark] def getCallSite(): String = {
  // Compute the default call site eagerly so the Option below wraps only the external
  // override lookup; this keeps Option.scala frames out of the generated call site
  // (previously the call site could report "apply at Option.scala:120").
  val defaultCallSite = Utils.getCallSiteInfo
  // getLocalProperty may return null when no external call site was set; Option(...)
  // maps that null to None and falls back to the formatted default.
  Option(getLocalProperty("externalCallSite")).getOrElse(defaultCallSite.toString)
}
/**
......
......@@ -1041,7 +1041,7 @@ abstract class RDD[T: ClassTag](
/** User code that created this RDD (e.g. `textFile`, `parallelize`). */
@transient private[spark] val creationSiteInfo = Utils.getCallSiteInfo
// Printable "method at file:line" string for the site that created this RDD,
// derived from the captured creationSiteInfo (see CallSiteInfo.toString).
private[spark] def getCreationSite: String = creationSiteInfo.toString
private[spark] def elementClassTag: ClassTag[T] = classTag[T]
......
......@@ -679,7 +679,13 @@ private[spark] object Utils extends Logging {
private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
/**
 * Describes the user-code location that triggered a Spark call.
 *
 * @param lastSparkMethod the shallowest Spark-internal method on the stack (e.g. "makeRDD")
 * @param firstUserFile   source file of the first non-Spark stack frame
 * @param firstUserLine   line number within that file
 * @param firstUserClass  fully qualified class of the first non-Spark frame
 */
private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
    val firstUserLine: Int, val firstUserClass: String) {
  /** Returns a printable version of the call site info suitable for logs. */
  override def toString: String = {
    "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
  }
}
/**
* When called inside a class in the spark package, returns the name of the user code class
......@@ -687,8 +693,8 @@ private[spark] object Utils extends Logging {
* This is used, for example, to tell users where in their code each RDD got created.
*/
def getCallSiteInfo: CallSiteInfo = {
val trace = Thread.currentThread.getStackTrace().filter( el =>
(!el.getMethodName.contains("getStackTrace")))
val trace = Thread.currentThread.getStackTrace()
.filterNot(_.getMethodName.contains("getStackTrace"))
// Keep crawling up the stack trace until we find the first function not inside of the spark
// package. We track the last (shallowest) contiguous Spark method. This might be an RDD
......@@ -721,12 +727,6 @@ private[spark] object Utils extends Logging {
new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
}
/** Returns a printable version of the call site info suitable for logs. */
def formatCallSiteInfo(callSiteInfo: CallSiteInfo = Utils.getCallSiteInfo) =
  s"${callSiteInfo.lastSparkMethod} at ${callSiteInfo.firstUserFile}:${callSiteInfo.firstUserLine}"
/** Return a string containing part of a file from byte 'start' to 'end'. */
def offsetBytes(path: String, start: Long, end: Long): String = {
val file = new File(path)
......
......@@ -17,7 +17,7 @@
package org.apache.spark
import org.scalatest.FunSuite
import org.scalatest.{Assertions, FunSuite}
class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
test("getPersistentRDDs only returns RDDs that are marked as cached") {
......@@ -56,4 +56,38 @@ class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
rdd.collect()
assert(sc.getRDDStorageInfo.size === 1)
}
// Regression test: call sites must point at user code, not at Option.scala
// (see the filtering done while generating the call site). The actual checks
// live in testPackage so the caller is outside org.apache.spark packages.
test("call sites report correct locations") {
  sc = new SparkContext("local", "test")
  testPackage.runCallSiteTest(sc)
}
}
/** Call site must be outside of usual org.apache.spark packages (see Utils#SPARK_CLASS_REGEX). */
package object testPackage extends Assertions {
  // Expected call-site shape: "<method> at <file>:<line>".
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  /**
   * Creates an RDD, then verifies that both the RDD's creation site and the current
   * call site report the correct method, file, and (relative) line numbers.
   *
   * NOTE: the assertions depend on the physical layout below — `curCallSite` is read
   * exactly two source lines after `rdd` is defined. Do not insert lines between them.
   */
  def runCallSiteTest(sc: SparkContext): Unit = {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    val curCallSite = sc.getCallSite() // note: 2 lines after definition of "rdd"
    // `fail` returns Nothing, so the match's type is Int; no extra conversion needed.
    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      case _ => fail("Did not match expected call site format")
    }
    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        assert(line.toInt === rddCreationLine + 2)
      case _ => fail("Did not match expected call site format")
    }
  }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment