diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index f87923e6fa4ebde8c8e5bad3b4b8ae53b673a68d..6fb4a7b3be25dfaeea2a7bdca0512868ade70a58 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -23,6 +23,7 @@ import operator
 import os
 import sys
 import shlex
+import traceback
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
@@ -39,6 +40,46 @@ from py4j.java_collections import ListConverter, MapConverter
 
 __all__ = ["RDD"]
 
+def _extract_concise_traceback():
+    tb = traceback.extract_stack()
+    if len(tb) == 0:
+        return "I'm lost!"
+    # HACK:  This function is in a file called 'rdd.py' in the top level of
+    # everything PySpark.  Just trim off the directory name and assume
+    # everything in that tree is PySpark guts.
+    file, line, module, what = tb[len(tb) - 1]
+    sparkpath = os.path.dirname(file)
+    first_spark_frame = len(tb) - 1
+    for i in range(0, len(tb)):
+        file, line, fun, what = tb[i]
+        if file.startswith(sparkpath):
+            first_spark_frame = i
+            break
+    if first_spark_frame == 0:
+        file, line, fun, what = tb[0]
+        return "%s at %s:%d" % (fun, file, line)
+    sfile, sline, sfun, swhat = tb[first_spark_frame]
+    ufile, uline, ufun, uwhat = tb[first_spark_frame-1]
+    return "%s at %s:%d" % (sfun, ufile, uline)
+
+_spark_stack_depth = 0
+
+class _JavaStackTrace(object):
+    def __init__(self, sc):
+        self._traceback = _extract_concise_traceback()
+        self._context = sc
+
+    def __enter__(self):
+        global _spark_stack_depth
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._traceback)
+        _spark_stack_depth += 1
+
+    def __exit__(self, type, value, tb):
+        global _spark_stack_depth
+        _spark_stack_depth -= 1
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(None)
 
 class RDD(object):
     """
@@ -401,7 +442,8 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        bytesInJava = self._jrdd.collect().iterator()
+        with _JavaStackTrace(self.context) as st:
+          bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))
 
     def _collect_iterator_through_file(self, iterator):
@@ -582,13 +624,14 @@ class RDD(object):
         # TODO(shivaram): Similar to the scala implementation, update the take 
         # method to scan multiple splits based on an estimate of how many elements 
         # we have per-split.
-        for partition in range(mapped._jrdd.splits().size()):
-            partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-            partitionsToTake[0] = partition
-            iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-            items.extend(mapped._collect_iterator_through_file(iterator))
-            if len(items) >= num:
-                break
+        with _JavaStackTrace(self.context) as st:
+            for partition in range(mapped._jrdd.splits().size()):
+                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+                partitionsToTake[0] = partition
+                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+                items.extend(mapped._collect_iterator_through_file(iterator))
+                if len(items) >= num:
+                    break
         return items[:num]
 
     def first(self):
@@ -765,9 +808,10 @@ class RDD(object):
                 yield outputSerializer.dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
-        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
-                                                     id(partitionFunc))
+        with _JavaStackTrace(self.context) as st:
+            pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+            partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+                                                          id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if