Commit fec01664 authored by Tor Myklebust

Make Python function/line appear in the UI.

parent d812aeec
@@ -23,6 +23,7 @@ import operator
 import os
 import sys
 import shlex
+import traceback
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
@@ -39,6 +40,46 @@ from py4j.java_collections import ListConverter, MapConverter
 
 __all__ = ["RDD"]
 
+def _extract_concise_traceback():
+    tb = traceback.extract_stack()
+    if len(tb) == 0:
+        return "I'm lost!"
+    # HACK: This function is in a file called 'rdd.py' in the top level of
+    # everything PySpark. Just trim off the directory name and assume
+    # everything in that tree is PySpark guts.
+    file, line, module, what = tb[len(tb) - 1]
+    sparkpath = os.path.dirname(file)
+    first_spark_frame = len(tb) - 1
+    for i in range(0, len(tb)):
+        file, line, fun, what = tb[i]
+        if file.startswith(sparkpath):
+            first_spark_frame = i
+            break
+    if first_spark_frame == 0:
+        file, line, fun, what = tb[0]
+        return "%s at %s:%d" % (fun, file, line)
+    sfile, sline, sfun, swhat = tb[first_spark_frame]
+    ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
+    return "%s at %s:%d" % (sfun, ufile, uline)
+
+_spark_stack_depth = 0
+
+class _JavaStackTrace(object):
+    def __init__(self, sc):
+        self._traceback = _extract_concise_traceback()
+        self._context = sc
+
+    def __enter__(self):
+        global _spark_stack_depth
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._traceback)
+        _spark_stack_depth += 1
+
+    def __exit__(self, type, value, tb):
+        global _spark_stack_depth
+        _spark_stack_depth -= 1
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(None)
 
 class RDD(object):
     """
@@ -401,7 +442,8 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        bytesInJava = self._jrdd.collect().iterator()
+        with _JavaStackTrace(self.context) as st:
+            bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))
 
     def _collect_iterator_through_file(self, iterator):
@@ -582,13 +624,14 @@ class RDD(object):
         # TODO(shivaram): Similar to the scala implementation, update the take
         # method to scan multiple splits based on an estimate of how many elements
         # we have per-split.
-        for partition in range(mapped._jrdd.splits().size()):
-            partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-            partitionsToTake[0] = partition
-            iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-            items.extend(mapped._collect_iterator_through_file(iterator))
-            if len(items) >= num:
-                break
+        with _JavaStackTrace(self.context) as st:
+            for partition in range(mapped._jrdd.splits().size()):
+                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+                partitionsToTake[0] = partition
+                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+                items.extend(mapped._collect_iterator_through_file(iterator))
+                if len(items) >= num:
+                    break
         return items[:num]
 
     def first(self):
@@ -765,9 +808,10 @@ class RDD(object):
             yield outputSerializer.dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
-        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
-                                                      id(partitionFunc))
+        with _JavaStackTrace(self.context) as st:
+            pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+            partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+                                                          id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if
...
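
Note (not part of the commit): _extract_concise_traceback works because traceback.extract_stack() returns (file, line, function, text) entries ordered oldest frame first, with the current frame last. The function scans for the first frame whose file lives under the PySpark source directory; the frame just above it is the user code that triggered the operation, and the reported string combines the Spark-side function name with the user-side location. A minimal, self-contained sketch of the same idea; the _where_called and collect names here are hypothetical stand-ins, and fixed stack offsets replace the real directory-prefix scan:

import traceback

def _where_called():
    tb = traceback.extract_stack()
    # tb[-1] is this frame, tb[-2] is the library function that asked
    # (e.g. collect), and tb[-3] is the user code that invoked it.
    _, _, libfun, _ = tb[-2]
    ufile, uline, _, _ = tb[-3]
    return "%s at %s:%d" % (libfun, ufile, uline)

def collect():
    # Stand-in for an RDD method; the real helper finds the user/Spark
    # boundary by comparing filenames against the PySpark directory
    # rather than relying on fixed offsets.
    return _where_called()

print(collect())   # prints something like "collect at demo.py:17"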
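
The _spark_stack_depth counter is what keeps nested wrapped operations from clobbering the call site: only the outermost with _JavaStackTrace(...) block calls setCallSite with the user's location, and only the outermost exit clears it, so an inner wrapped operation neither overwrites nor prematurely erases the string shown in the UI. A toy version of the pattern (my sketch; CallSiteGuard is a hypothetical name, and a module variable stands in for the call-site state held by JavaSparkContext):

_depth = 0
_call_site = None   # stands in for the call site held by the JVM

class CallSiteGuard(object):
    def __init__(self, site):
        self._site = site

    def __enter__(self):
        global _depth, _call_site
        if _depth == 0:          # only the outermost block sets the site
            _call_site = self._site
        _depth += 1

    def __exit__(self, type, value, tb):
        global _depth, _call_site
        _depth -= 1
        if _depth == 0:          # only the outermost exit clears it
            _call_site = None

with CallSiteGuard("take at user_script.py:3"):
    with CallSiteGuard("collect at rdd.py:443"):   # nested: ignored
        print(_call_site)   # -> take at user_script.py:3
print(_call_site)           # -> None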