diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 84bd0f7ffdf6426bad73022c12f2a5331c1464a7..4d6a97e255d348c2d229f2b6acf373e50b0d4764 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -728,6 +728,26 @@ class SparkContext(
     conf.getOption("spark.home").orElse(Option(System.getenv("SPARK_HOME")))
   }
 
+  /**
+   * Set the call site reported for jobs and RDDs created by this thread. For API backtraces.
+   */
+  def setCallSite(site: String) {
+    setLocalProperty("externalCallSite", site)
+  }
+
+  /**
+   * Clear the call site set by setCallSite. For API backtraces.
+   */
+  def clearCallSite() {
+    setLocalProperty("externalCallSite", null)
+  }
+
+  private[spark] def getCallSite(): String = {
+    // Prefer a call site set by an external API wrapper (e.g. PySpark); otherwise
+    // fall back to the Scala stack trace of the current call.
+    Option(getLocalProperty("externalCallSite")).getOrElse(Utils.formatSparkCallSite)
+  }
+
   /**
    * Run a function on a given set of partitions in an RDD and pass the results to the given
    * handler function. This is the main entry point for all actions in Spark. The allowLocal
@@ -740,7 +760,7 @@ class SparkContext(
       partitions: Seq[Int],
       allowLocal: Boolean,
       resultHandler: (Int, U) => Unit) {
-    val callSite = Utils.formatSparkCallSite
+    val callSite = getCallSite
     val cleanedFunc = clean(func)
     logInfo("Starting job: " + callSite)
     val start = System.nanoTime
@@ -824,7 +844,7 @@ class SparkContext(
       func: (TaskContext, Iterator[T]) => U,
       evaluator: ApproximateEvaluator[U, R],
       timeout: Long): PartialResult[R] = {
-    val callSite = Utils.formatSparkCallSite
+    val callSite = getCallSite
     logInfo("Starting job: " + callSite)
     val start = System.nanoTime
     val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout,
@@ -844,7 +864,7 @@ class SparkContext(
       resultFunc: => R): SimpleFutureAction[R] =
   {
     val cleanF = clean(processPartition)
-    val callSite = Utils.formatSparkCallSite
+    val callSite = getCallSite
     val waiter = dagScheduler.submitJob(
       rdd,
       (context: TaskContext, iter: Iterator[T]) => cleanF(iter),
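
The hunks above give SparkContext a per-thread "externalCallSite" property and make getCallSite fall back to Utils.formatSparkCallSite when it is unset. A minimal Scala sketch of the intended use from a wrapper API (not part of this patch; the app name and site string are made-up example values):

    val sc = new SparkContext("local", "callsite-demo")
    sc.setCallSite("count at my_script.py:17")   // jobs from this thread now report the external site
    try {
      sc.parallelize(1 to 100).count()           // driver log: "Starting job: count at my_script.py:17"
    } finally {
      sc.clearCallSite()                         // getCallSite falls back to Utils.formatSparkCallSite
    }
    sc.stop()
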
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 0680a065a3082d1df80083a924c1d93970105f94..5be5317f40e7eb2a0cc4b821a282b27333e0b84b 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -411,6 +411,20 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    * changed at runtime.
    */
   def getConf: SparkConf = sc.getConf
+
+  /**
+   * Pass-through to SparkContext.setCallSite.  For API support only.
+   */
+  def setCallSite(site: String) {
+    sc.setCallSite(site)
+  }
+
+  /**
+   * Pass-through to SparkContext.clearCallSite.  For API support only.
+   */
+  def clearCallSite() {
+    sc.clearCallSite()
+  }
 }
 
 object JavaSparkContext {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6a7b0f8a86b6d43826b76d099d3b69c9dc1ca655..3f41b662799873694958b2b699563c0627cab499 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -953,7 +953,7 @@ abstract class RDD[T: ClassTag](
   private var storageLevel: StorageLevel = StorageLevel.NONE
 
   /** Record user function generating this RDD. */
-  @transient private[spark] val origin = Utils.formatSparkCallSite
+  @transient private[spark] val origin = sc.getCallSite
 
   private[spark] def elementClassTag: ClassTag[T] = classTag[T]
 
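
Since `origin` is now routed through sc.getCallSite, an RDD created while an external call site is set records that site as the user code that produced it. A hedged sketch of the effect, assuming an existing SparkContext `sc` (the site string is a made-up example; `origin` itself stays private[spark]):

    sc.setCallSite("parallelize at my_script.py:3")
    val rdd = sc.parallelize(1 to 10)   // rdd.origin records "parallelize at my_script.py:3"
    sc.clearCallSite()                  // RDDs created afterwards record the Scala call site again
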
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index f87923e6fa4ebde8c8e5bad3b4b8ae53b673a68d..6fb4a7b3be25dfaeea2a7bdca0512868ade70a58 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -23,6 +23,7 @@ import operator
 import os
 import sys
 import shlex
+import traceback
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
@@ -39,6 +40,46 @@ from py4j.java_collections import ListConverter, MapConverter
 
 __all__ = ["RDD"]
 
+def _extract_concise_traceback():
+    tb = traceback.extract_stack()
+    if len(tb) == 0:
+        return "I'm lost!"
+    # HACK: this function lives in rdd.py at the top level of the pyspark
+    # package, so treat every frame under that directory as PySpark guts
+    # and report the frame just outside that tree as the user's call site.
+    file, line, fun, what = tb[-1]
+    sparkpath = os.path.dirname(file)
+    first_spark_frame = len(tb) - 1
+    for i in range(0, len(tb)):
+        file, line, fun, what = tb[i]
+        if file.startswith(sparkpath):
+            first_spark_frame = i
+            break
+    if first_spark_frame == 0:
+        file, line, fun, what = tb[0]
+        return "%s at %s:%d" % (fun, file, line)
+    sfile, sline, sfun, swhat = tb[first_spark_frame]
+    ufile, uline, ufun, uwhat = tb[first_spark_frame-1]
+    return "%s at %s:%d" % (sfun, ufile, uline)
+
+_spark_stack_depth = 0  # nesting depth of _JavaStackTrace context managers
+
+class _JavaStackTrace(object):
+    def __init__(self, sc):
+        self._traceback = _extract_concise_traceback()
+        self._context = sc
+
+    def __enter__(self):
+        global _spark_stack_depth
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._traceback)
+        _spark_stack_depth += 1
+
+    def __exit__(self, type, value, tb):
+        global _spark_stack_depth
+        _spark_stack_depth -= 1
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(None)
 
 class RDD(object):
     """
@@ -401,7 +442,8 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        bytesInJava = self._jrdd.collect().iterator()
+        with _JavaStackTrace(self.context) as st:
+            bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))
 
     def _collect_iterator_through_file(self, iterator):
@@ -582,13 +624,14 @@ class RDD(object):
         # TODO(shivaram): Similar to the scala implementation, update the take 
         # method to scan multiple splits based on an estimate of how many elements 
         # we have per-split.
-        for partition in range(mapped._jrdd.splits().size()):
-            partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-            partitionsToTake[0] = partition
-            iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-            items.extend(mapped._collect_iterator_through_file(iterator))
-            if len(items) >= num:
-                break
+        with _JavaStackTrace(self.context) as st:
+            for partition in range(mapped._jrdd.splits().size()):
+                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+                partitionsToTake[0] = partition
+                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+                items.extend(mapped._collect_iterator_through_file(iterator))
+                if len(items) >= num:
+                    break
         return items[:num]
 
     def first(self):
@@ -765,9 +808,10 @@ class RDD(object):
                 yield outputSerializer.dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
-        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
-                                                     id(partitionFunc))
+        with _JavaStackTrace(self.context) as st:
+            pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+            partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+                                                          id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if
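
Note that _JavaStackTrace clears the site with setCallSite(None) rather than clearCallSite(); the two are equivalent, since both store a null "externalCallSite" property, which getCallSite treats as unset. A hedged Scala sketch of the call sequence the JVM sees for one outermost PySpark operation (the app name and site string are illustrative):

    import org.apache.spark.api.java.JavaSparkContext

    val jsc = new JavaSparkContext("local", "callsite-demo")
    // Outermost Python block enters (depth 0 -> 1): label subsequent jobs with the user's frame.
    jsc.setCallSite("collect at my_script.py:12")
    // Nested PySpark calls run here; depth > 0, so the call site is left untouched.
    // Outermost block exits (depth 1 -> 0): clear the label, equivalent to clearCallSite().
    jsc.setCallSite(null)
    jsc.stop()
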