diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index eb0b0db0cc2efb9331b12c0e6e09778d115a7883..ef9bf4db9b1b1d60373a9fa6ca842b2bee055ef2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -221,18 +221,6 @@ private[spark] object PythonRDD { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def writeStringAsPickle(elem: String, dOut: DataOutputStream) { - val s = elem.getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } - def writeToStream(elem: Any, dataOut: DataOutputStream) { elem match { case bytes: Array[Byte] => @@ -244,9 +232,7 @@ private[spark] object PythonRDD { dataOut.writeInt(pair._2.length) dataOut.write(pair._2) case str: String => - // Until we've implemented full custom serializer support, we need to return - // strings as Pickles to properly support union() and cartesian(): - writeStringAsPickle(str, dataOut) + dataOut.writeUTF(str) case other => throw new SparkException("Unexpected element type " + other.getClass) } @@ -271,13 +257,6 @@ private[spark] object PythonRDD { } } -private object Pickle { - val PROTO: Byte = 0x80.toByte - val TWO: Byte = 0x02.toByte - val BINUNICODE: Byte = 'X' - val STOP: Byte = '.' -} - private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/python/epydoc.conf b/python/epydoc.conf index 1d0d002d36623409f341688811d4d846ec6ee9fe..0b42e729f8dcc756c711584de2b2a4f071b480c5 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -32,6 +32,6 @@ target: docs/ private: no -exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers +exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test pyspark.rddsampler pyspark.daemon diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index da3d96689aa15dc14707e8f88c9170e51af3cede..2204e9c9ca7011f10ca60c900896db13868c35ff 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -90,9 +90,11 @@ import struct import SocketServer import threading from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import read_int, read_with_length, load_pickle +from pyspark.serializers import read_int, PickleSerializer +pickleSer = PickleSerializer() + # Holds accumulators registered on the current machine, keyed by ID. This is then used to send # the local accumulator updates back to the driver program at the end of a task. _accumulatorRegistry = {} @@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler): from pyspark.accumulators import _accumulatorRegistry num_updates = read_int(self.rfile) for _ in range(num_updates): - (aid, update) = load_pickle(read_with_length(self.rfile)) + (aid, update) = pickleSer._read_with_length(self.rfile) _accumulatorRegistry[aid] += update # Write a byte in acknowledgement self.wfile.write(struct.pack("!b", 1)) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 0fec1a6bf67d4d2afeeb06cd605c64f3055f722b..6bb1c6c3a1e4bc77e924bd6f7365ed981b4a73a4 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway -from pyspark.serializers import dump_pickle, write_with_length, batched +from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD @@ -51,7 +51,7 @@ class SparkContext(object): def __init__(self, master, jobName, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024): + environment=None, batchSize=1024, serializer=PickleSerializer()): """ Create a new SparkContext. @@ -67,6 +67,7 @@ class SparkContext(object): @param batchSize: The number of Python objects represented as a single Java object. Set 1 to disable batching or -1 to use an unlimited batch size. + @param serializer: The serializer for RDDs. >>> from pyspark.context import SparkContext @@ -83,7 +84,13 @@ class SparkContext(object): self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J self.environment = environment or {} - self.batchSize = batchSize # -1 represents a unlimited batch size + self._batchSize = batchSize # -1 represents an unlimited batch size + self._unbatched_serializer = serializer + if batchSize == 1: + self.serializer = self._unbatched_serializer + else: + self.serializer = BatchedSerializer(self._unbatched_serializer, + batchSize) # Create the Java SparkContext through Py4J empty_string_array = self._gateway.new_array(self._jvm.String, 0) @@ -184,15 +191,17 @@ class SparkContext(object): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = min(len(c) // numSlices, self.batchSize) + batchSize = min(len(c) // numSlices, self._batchSize) if batchSize > 1: - c = batched(c, batchSize) - for x in c: - write_with_length(dump_pickle(x), tempFile) + serializer = BatchedSerializer(self._unbatched_serializer, + batchSize) + else: + serializer = self._unbatched_serializer + serializer.dump_stream(c, tempFile) tempFile.close() readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) - return RDD(jrdd, self) + return RDD(jrdd, self, serializer) def textFile(self, name, minSplits=None): """ @@ -201,21 +210,39 @@ class SparkContext(object): RDD of Strings. """ minSplits = minSplits or min(self.defaultParallelism, 2) - jrdd = self._jsc.textFile(name, minSplits) - return RDD(jrdd, self) + return RDD(self._jsc.textFile(name, minSplits), self, + MUTF8Deserializer()) - def _checkpointFile(self, name): + def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) - return RDD(jrdd, self) + return RDD(jrdd, self, input_deserializer) def union(self, rdds): """ Build the union of a list of RDDs. + + This supports unions() of RDDs with different serialized formats, + although this forces them to be reserialized using the default + serializer: + + >>> path = os.path.join(tempdir, "union-text.txt") + >>> with open(path, "w") as testFile: + ... testFile.write("Hello") + >>> textFile = sc.textFile(path) + >>> textFile.collect() + [u'Hello'] + >>> parallelized = sc.parallelize(["World!"]) + >>> sorted(sc.union([textFile, parallelized]).collect()) + [u'Hello', 'World!'] """ + first_jrdd_deserializer = rdds[0]._jrdd_deserializer + if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds): + rdds = [x._reserialize() for x in rdds] first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self.gateway._gateway_client) - return RDD(self._jsc.union(first, rest), self) + rest = ListConverter().convert(rest, self._gateway._gateway_client) + return RDD(self._jsc.union(first, rest), self, + rdds[0]._jrdd_deserializer) def broadcast(self, value): """ @@ -223,7 +250,9 @@ class SparkContext(object): object for reading it in distributed functions. The variable will be sent to each cluster only once. """ - jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) + pickleSer = PickleSerializer() + pickled = pickleSer._dumps(value) + jbroadcast = self._jsc.broadcast(bytearray(pickled)) return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) @@ -235,7 +264,7 @@ class SparkContext(object): and floating-point numbers if you do not provide one. For other types, a custom AccumulatorParam can be used. """ - if accum_param == None: + if accum_param is None: if isinstance(value, int): accum_param = accumulators.INT_ACCUMULATOR_PARAM elif isinstance(value, float): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d3c4d13a1e0a28d5a66abd97442f67fba1686374..6691c30519db62e14ac661ae49d02336f5e92ada 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -18,7 +18,7 @@ from base64 import standard_b64encode as b64enc import copy from collections import defaultdict -from itertools import chain, ifilter, imap, product +from itertools import chain, ifilter, imap import operator import os import sys @@ -28,8 +28,8 @@ from tempfile import NamedTemporaryFile from threading import Thread from pyspark import cloudpickle -from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ - read_from_pickle_file, pack_long +from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ + BatchedSerializer, pack_long from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -48,13 +48,12 @@ class RDD(object): operated on in parallel. """ - def __init__(self, jrdd, ctx): + def __init__(self, jrdd, ctx, jrdd_deserializer): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False self.ctx = ctx - self._partitionFunc = None - self._stage_input_is_pairs = False + self._jrdd_deserializer = jrdd_deserializer @property def context(self): @@ -248,7 +247,23 @@ class RDD(object): >>> rdd.union(rdd).collect() [1, 1, 2, 3, 1, 1, 2, 3] """ - return RDD(self._jrdd.union(other._jrdd), self.ctx) + if self._jrdd_deserializer == other._jrdd_deserializer: + rdd = RDD(self._jrdd.union(other._jrdd), self.ctx, + self._jrdd_deserializer) + return rdd + else: + # These RDDs contain data in different serialized formats, so we + # must normalize them to the default serializer. + self_copy = self._reserialize() + other_copy = other._reserialize() + return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, + self.ctx.serializer) + + def _reserialize(self): + if self._jrdd_deserializer == self.ctx.serializer: + return self + else: + return self.map(lambda x: x, preservesPartitioning=True) def __add__(self, other): """ @@ -335,18 +350,9 @@ class RDD(object): [(1, 1), (1, 2), (2, 1), (2, 2)] """ # Due to batching, we can't use the Java cartesian method. - java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx) - def unpack_batches(pair): - (x, y) = pair - if type(x) == Batch or type(y) == Batch: - xs = x.items if type(x) == Batch else [x] - ys = y.items if type(y) == Batch else [y] - for pair in product(xs, ys): - yield pair - else: - yield pair - java_cartesian._stage_input_is_pairs = True - return java_cartesian.flatMap(unpack_batches) + deserializer = CartesianDeserializer(self._jrdd_deserializer, + other._jrdd_deserializer) + return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer) def groupBy(self, f, numPartitions=None): """ @@ -405,7 +411,7 @@ class RDD(object): self.ctx._writeToFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: - for item in read_from_pickle_file(tempFile): + for item in self._jrdd_deserializer.load_stream(tempFile): yield item os.unlink(tempFile.name) @@ -573,7 +579,7 @@ class RDD(object): items = [] for partition in range(mapped._jrdd.splits().size()): iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition) - items.extend(self._collect_iterator_through_file(iterator)) + items.extend(mapped._collect_iterator_through_file(iterator)) if len(items) >= num: break return items[:num] @@ -737,6 +743,7 @@ class RDD(object): # Transferring O(n) objects to Java is too expensive. Instead, we'll # form the hash buckets in Python, transferring O(numPartitions) objects # to Java. Each object is a (splitNumber, [objects]) pair. + outputSerializer = self.ctx._unbatched_serializer def add_shuffle_key(split, iterator): buckets = defaultdict(list) @@ -745,14 +752,14 @@ class RDD(object): buckets[partitionFunc(k) % numPartitions].append((k, v)) for (split, items) in buckets.iteritems(): yield pack_long(split) - yield dump_pickle(Batch(items)) + yield outputSerializer._dumps(items) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, id(partitionFunc)) jrdd = pairRDD.partitionBy(partitioner).values() - rdd = RDD(jrdd, self.ctx) + rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) # This is required so that id(partitionFunc) remains unique, even if # partitionFunc is a lambda: rdd._partitionFunc = partitionFunc @@ -789,7 +796,8 @@ class RDD(object): numPartitions = self.ctx.defaultParallelism def combineLocally(iterator): combiners = {} - for (k, v) in iterator: + for x in iterator: + (k, v) = x if k not in combiners: combiners[k] = createCombiner(v) else: @@ -931,38 +939,38 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and prev._is_pipelinable(): + if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable(): + # This transformation is the first in its stage: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jrdd = prev._jrdd + self._prev_jrdd_deserializer = prev._jrdd_deserializer + else: prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) self.func = pipeline_func self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning - self._prev_jrdd = prev._prev_jrdd - else: - self.func = func - self.preservesPartitioning = preservesPartitioning - self._prev_jrdd = prev._jrdd - self._stage_input_is_pairs = prev._stage_input_is_pairs + self._prev_jrdd = prev._prev_jrdd # maintain the pipeline + self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer self.is_cached = False self.is_checkpointed = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None + self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False @property def _jrdd(self): if self._jrdd_val: return self._jrdd_val - func = self.func - if not self._bypass_serializer and self.ctx.batchSize != 1: - oldfunc = self.func - batchSize = self.ctx.batchSize - def batched_func(split, iterator): - return batched(oldfunc(split, iterator), batchSize) - func = batched_func - cmds = [func, self._bypass_serializer, self._stage_input_is_pairs] + if self._bypass_serializer: + serializer = NoOpSerializer() + else: + serializer = self.ctx.serializer + cmds = [self.func, self._prev_jrdd_deserializer, serializer] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fd02e1ee8fbd7a8a83e2249f08da965f97703362..4fb444443f7849c7baa908b9b1b1adcbc3194318 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -15,8 +15,58 @@ # limitations under the License. # -import struct +""" +PySpark supports custom serializers for transferring data; this can improve +performance. + +By default, PySpark uses L{PickleSerializer} to serialize objects using Python's +C{cPickle} serializer, which can serialize nearly any Python object. +Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be +faster. + +The serializer is chosen when creating L{SparkContext}: + +>>> from pyspark.context import SparkContext +>>> from pyspark.serializers import MarshalSerializer +>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer()) +>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) +[0, 2, 4, 6, 8, 10, 12, 14, 16, 18] +>>> sc.stop() + +By default, PySpark serialize objects in batches; the batch size can be +controlled through SparkContext's C{batchSize} parameter +(the default size is 1024 objects): + +>>> sc = SparkContext('local', 'test', batchSize=2) +>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) + +Behind the scenes, this creates a JavaRDD with four partitions, each of +which contains two batches of two objects: + +>>> rdd.glom().collect() +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] +>>> rdd._jrdd.count() +8L +>>> sc.stop() + +A batch size of -1 uses an unlimited batch size, and a size of 1 disables +batching: + +>>> sc = SparkContext('local', 'test', batchSize=1) +>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) +>>> rdd.glom().collect() +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] +>>> rdd._jrdd.count() +16L +""" + import cPickle +from itertools import chain, izip, product +import marshal +import struct + + +__all__ = ["PickleSerializer", "MarshalSerializer"] class SpecialLengths(object): @@ -25,41 +75,206 @@ class SpecialLengths(object): TIMING_DATA = -3 -class Batch(object): +class Serializer(object): + + def dump_stream(self, iterator, stream): + """ + Serialize an iterator of objects to the output stream. + """ + raise NotImplementedError + + def load_stream(self, stream): + """ + Return an iterator of deserialized objects from the input stream. + """ + raise NotImplementedError + + + def _load_stream_without_unbatching(self, stream): + return self.load_stream(stream) + + # Note: our notion of "equality" is that output generated by + # equal serializers can be deserialized using the same serializer. + + # This default implementation handles the simple cases; + # subclasses should override __eq__ as appropriate. + + def __eq__(self, other): + return isinstance(other, self.__class__) + + def __ne__(self, other): + return not self.__eq__(other) + + +class FramedSerializer(Serializer): + """ + Serializer that writes objects as a stream of (length, data) pairs, + where C{length} is a 32-bit integer and data is C{length} bytes. + """ + + def dump_stream(self, iterator, stream): + for obj in iterator: + self._write_with_length(obj, stream) + + def load_stream(self, stream): + while True: + try: + yield self._read_with_length(stream) + except EOFError: + return + + def _write_with_length(self, obj, stream): + serialized = self._dumps(obj) + write_int(len(serialized), stream) + stream.write(serialized) + + def _read_with_length(self, stream): + length = read_int(stream) + obj = stream.read(length) + if obj == "": + raise EOFError + return self._loads(obj) + + def _dumps(self, obj): + """ + Serialize an object into a byte array. + When batching is used, this will be called with an array of objects. + """ + raise NotImplementedError + + def _loads(self, obj): + """ + Deserialize an object from a byte array. + """ + raise NotImplementedError + + +class BatchedSerializer(Serializer): + """ + Serializes a stream of objects in batches by calling its wrapped + Serializer with streams of objects. + """ + + UNLIMITED_BATCH_SIZE = -1 + + def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): + self.serializer = serializer + self.batchSize = batchSize + + def _batched(self, iterator): + if self.batchSize == self.UNLIMITED_BATCH_SIZE: + yield list(iterator) + else: + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == self.batchSize: + yield items + items = [] + count = 0 + if items: + yield items + + def dump_stream(self, iterator, stream): + if isinstance(iterator, basestring): + iterator = [iterator] + self.serializer.dump_stream(self._batched(iterator), stream) + + def load_stream(self, stream): + return chain.from_iterable(self._load_stream_without_unbatching(stream)) + + def _load_stream_without_unbatching(self, stream): + return self.serializer.load_stream(stream) + + def __eq__(self, other): + return isinstance(other, BatchedSerializer) and \ + other.serializer == self.serializer + + def __str__(self): + return "BatchedSerializer<%s>" % str(self.serializer) + + +class CartesianDeserializer(FramedSerializer): """ - Used to store multiple RDD entries as a single Java object. + Deserializes the JavaRDD cartesian() of two PythonRDDs. + """ + + def __init__(self, key_ser, val_ser): + self.key_ser = key_ser + self.val_ser = val_ser + + def load_stream(self, stream): + key_stream = self.key_ser._load_stream_without_unbatching(stream) + val_stream = self.val_ser._load_stream_without_unbatching(stream) + key_is_batched = isinstance(self.key_ser, BatchedSerializer) + val_is_batched = isinstance(self.val_ser, BatchedSerializer) + for (keys, vals) in izip(key_stream, val_stream): + keys = keys if key_is_batched else [keys] + vals = vals if val_is_batched else [vals] + for pair in product(keys, vals): + yield pair + + def __eq__(self, other): + return isinstance(other, CartesianDeserializer) and \ + self.key_ser == other.key_ser and self.val_ser == other.val_ser + + def __str__(self): + return "CartesianDeserializer<%s, %s>" % \ + (str(self.key_ser), str(self.val_ser)) + + +class NoOpSerializer(FramedSerializer): + + def _loads(self, obj): return obj + def _dumps(self, obj): return obj + + +class PickleSerializer(FramedSerializer): + """ + Serializes objects using Python's cPickle serializer: + + http://docs.python.org/2/library/pickle.html + + This serializer supports nearly any Python object, but may + not be as fast as more specialized serializers. + """ + + def _dumps(self, obj): return cPickle.dumps(obj, 2) + _loads = cPickle.loads + - This relieves us from having to explicitly track whether an RDD - is stored as batches of objects and avoids problems when processing - the union() of batched and unbatched RDDs (e.g. the union() of textFile() - with another RDD). +class MarshalSerializer(FramedSerializer): """ - def __init__(self, items): - self.items = items + Serializes objects using Python's Marshal serializer: + http://docs.python.org/2/library/marshal.html -def batched(iterator, batchSize): - if batchSize == -1: # unlimited batch size - yield Batch(list(iterator)) - else: - items = [] - count = 0 - for item in iterator: - items.append(item) - count += 1 - if count == batchSize: - yield Batch(items) - items = [] - count = 0 - if items: - yield Batch(items) + This serializer is faster than PickleSerializer but supports fewer datatypes. + """ + + _dumps = marshal.dumps + _loads = marshal.loads -def dump_pickle(obj): - return cPickle.dumps(obj, 2) +class MUTF8Deserializer(Serializer): + """ + Deserializes streams written by Java's DataOutputStream.writeUTF(). + """ + def _loads(self, stream): + length = struct.unpack('>H', stream.read(2))[0] + return stream.read(length).decode('utf8') -load_pickle = cPickle.loads + def load_stream(self, stream): + while True: + try: + yield self._loads(stream) + except struct.error: + return + except EOFError: + return def read_long(stream): @@ -90,43 +305,4 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) - stream.write(obj) - - -def read_mutf8(stream): - """ - Read a string written with Java's DataOutputStream.writeUTF() method. - """ - length = struct.unpack('>H', stream.read(2))[0] - return stream.read(length).decode('utf8') - - -def read_with_length(stream): - length = read_int(stream) - obj = stream.read(length) - if obj == "": - raise EOFError - return obj - - -def read_from_pickle_file(stream): - try: - while True: - obj = load_pickle(read_with_length(stream)) - if type(obj) == Batch: # We don't care about inheritance - for item in obj.items: - yield item - else: - yield obj - except EOFError: - return - - -def read_pairs_from_pickle_file(stream): - try: - while True: - a = load_pickle(read_with_length(stream)) - b = load_pickle(read_with_length(stream)) - yield (a, b) - except EOFError: - return \ No newline at end of file + stream.write(obj) \ No newline at end of file diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 29d6a128f6a9b83dc742c676ef010a90f54ab73e..621e1cb58c3df10afa1f64a8f7b9f988dd71b0cb 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -86,7 +86,8 @@ class TestCheckpoint(PySparkTestCase): time.sleep(1) # 1 second self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) - recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) + recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), + flatMappedRDD._jrdd_deserializer) self.assertEquals([1, 2, 3, 4], recovered.collect()) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4e64557fc49e414657f14cdfe645807f7e5bffbd..5b16d5db7e859bd4b008063ca886565b6a73a2c6 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -30,13 +30,17 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles -from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \ - SpecialLengths, read_mutf8, read_pairs_from_pickle_file +from pyspark.serializers import write_with_length, write_int, read_long, \ + write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer + + +pickleSer = PickleSerializer() +mutf8_deserializer = MUTF8Deserializer() def load_obj(infile): - return load_pickle(standard_b64decode(infile.readline().strip())) + decoded = standard_b64decode(infile.readline().strip()) + return pickleSer._loads(decoded) def report_times(outfile, boot, init, finish): @@ -53,7 +57,7 @@ def main(infile, outfile): return # fetch name of workdir - spark_files_dir = read_mutf8(infile) + spark_files_dir = mutf8_deserializer._loads(infile) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True @@ -61,31 +65,24 @@ def main(infile, outfile): num_broadcast_variables = read_int(infile) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) + value = pickleSer._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): - sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile))) + filename = mutf8_deserializer._loads(infile) + sys.path.append(os.path.join(spark_files_dir, filename)) - # now load function + # Load this stage's function and serializer: func = load_obj(infile) - bypassSerializer = load_obj(infile) - stageInputIsPairs = load_obj(infile) - if bypassSerializer: - dumps = lambda x: x - else: - dumps = dump_pickle + deserializer = load_obj(infile) + serializer = load_obj(infile) init_time = time.time() - if stageInputIsPairs: - iterator = read_pairs_from_pickle_file(infile) - else: - iterator = read_from_pickle_file(infile) try: - for obj in func(split_index, iterator): - write_with_length(dumps(obj), outfile) + iterator = deserializer.load_stream(infile) + serializer.dump_stream(func(split_index, iterator), outfile) except Exception as e: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) write_with_length(traceback.format_exc(), outfile) @@ -96,7 +93,7 @@ def main(infile, outfile): write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) write_int(len(_accumulatorRegistry), outfile) for (aid, accum) in _accumulatorRegistry.items(): - write_with_length(dump_pickle((aid, accum._value)), outfile) + pickleSer._write_with_length((aid, accum._value), outfile) if __name__ == '__main__': diff --git a/python/run-tests b/python/run-tests index cbc554ea9db0d2cdd18323a408879d98f80810cb..d4dad672d284299a0a1fc116b24f7833eefe1b7b 100755 --- a/python/run-tests +++ b/python/run-tests @@ -37,6 +37,7 @@ run_test "pyspark/rdd.py" run_test "pyspark/context.py" run_test "-m doctest pyspark/broadcast.py" run_test "-m doctest pyspark/accumulators.py" +run_test "-m doctest pyspark/serializers.py" run_test "pyspark/tests.py" if [[ $FAILED != 0 ]]; then