diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 12b4d94a567ceec4713299fa49a54349b08c0cef..132e4fb0d2cadcd2018437f239c1519cee957bbd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -27,13 +27,12 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.PipedRDD import org.apache.spark.util.Utils private[spark] class PythonRDD[T: ClassManifest]( parent: RDD[T], - command: Seq[String], + command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], preservePartitoning: Boolean, @@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassManifest]( val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], - accumulator: Accumulator[JList[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec, - broadcastVars, accumulator) - override def getPartitions = parent.partitions override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get @@ -71,11 +59,10 @@ private[spark] class PythonRDD[T: ClassManifest]( SparkEnv.set(env) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) - val printOut = new PrintWriter(stream) // Partition index dataOut.writeInt(split.index) // sparkFilesDir - PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut) + dataOut.writeUTF(SparkFiles.getRootDirectory) // Broadcast variables dataOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { @@ -85,21 +72,16 @@ private[spark] class PythonRDD[T: ClassManifest]( } // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.length) - for (f <- pythonIncludes) { - PythonRDD.writeAsPickle(f, dataOut) - } + pythonIncludes.foreach(dataOut.writeUTF) dataOut.flush() - // Serialized user code - for (elem <- command) { - printOut.println(elem) - } - printOut.flush() + // Serialized command: + dataOut.writeInt(command.length) + dataOut.write(command) // Data values for (elem <- parent.iterator(split, context)) { - PythonRDD.writeAsPickle(elem, dataOut) + PythonRDD.writeToStream(elem, dataOut) } dataOut.flush() - printOut.flush() worker.shutdownOutput() } catch { case e: IOException => @@ -132,7 +114,7 @@ private[spark] class PythonRDD[T: ClassManifest]( val obj = new Array[Byte](length) stream.readFully(obj) obj - case -3 => + case SpecialLengths.TIMING_DATA => // Timing data from worker val bootTime = stream.readLong() val initTime = stream.readLong() @@ -143,24 +125,24 @@ private[spark] class PythonRDD[T: ClassManifest]( val total = finishTime - startTime logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) read - case -2 => + case 
SpecialLengths.PYTHON_EXCEPTION_THROWN => // Signals that an exception has been thrown in python val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) throw new PythonException(new String(obj)) - case -1 => + case SpecialLengths.END_OF_DATA_SECTION => // We've finished the data section of the output, but we can still - // read some accumulator updates; let's do that, breaking when we - // get a negative length record. - var len2 = stream.readInt() - while (len2 >= 0) { - val update = new Array[Byte](len2) + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) stream.readFully(update) accumulator += Collections.singletonList(update) - len2 = stream.readInt() + } - new Array[Byte](0) + Array.empty[Byte] } } catch { case eof: EOFException => { @@ -197,62 +179,15 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } -private[spark] object PythonRDD { - - /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ - def stripPickle(arr: Array[Byte]) : Array[Byte] = { - arr.slice(2, arr.length - 1) - } +private object SpecialLengths { + val END_OF_DATA_SECTION = -1 + val PYTHON_EXCEPTION_THROWN = -2 + val TIMING_DATA = -3 +} - /** - * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. - * The data format is a 32-bit integer representing the pickled object's length (in bytes), - * followed by the pickled data. - * - * Pickle module: - * - * http://docs.python.org/2/library/pickle.html - * - * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules: - * - * http://hg.python.org/cpython/file/2.6/Lib/pickle.py - * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py - * - * @param elem the object to write - * @param dOut a data output stream - */ - def writeAsPickle(elem: Any, dOut: DataOutputStream) { - if (elem.isInstanceOf[Array[Byte]]) { - val arr = elem.asInstanceOf[Array[Byte]] - dOut.writeInt(arr.length) - dOut.write(arr) - } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { - val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] - val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t._1)) - dOut.write(PythonRDD.stripPickle(t._2)) - dOut.writeByte(Pickle.TUPLE2) - dOut.writeByte(Pickle.STOP) - } else if (elem.isInstanceOf[String]) { - // For uniformity, strings are wrapped into Pickles. 
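The pickle-opcode splicing removed below is replaced by the much simpler PythonRDD.writeToStream defined a few hunks further on: byte arrays and byte-array pairs become 32-bit-length-framed blobs, and strings go through writeUTF. A rough Python rendition of that framing, for illustration only (the real code is the Scala writeToStream in this patch):

import struct
from io import BytesIO

def write_to_stream(elem, stream):
    # Sketch of PythonRDD.writeToStream; not the real implementation.
    if isinstance(elem, tuple):            # (Array[Byte], Array[Byte]) pair
        for part in elem:
            stream.write(struct.pack('!i', len(part)))   # 32-bit length prefix
            stream.write(part)
    elif isinstance(elem, bytes):          # Array[Byte]
        stream.write(struct.pack('!i', len(elem)))
        stream.write(elem)
    else:                                  # String -> DataOutputStream.writeUTF
        data = elem.encode('utf8')
        stream.write(struct.pack('>H', len(data)))       # 16-bit length prefix
        stream.write(data)

out = BytesIO()
write_to_stream((b"key", b"value"), out)   # two framed blobs, back to back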
- val s = elem.asInstanceOf[String].getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } else { - throw new SparkException("Unexpected RDD type") - } - } +private[spark] object PythonRDD { - def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) val objs = new collection.mutable.ArrayBuffer[Array[Byte]] @@ -270,15 +205,32 @@ private[spark] object PythonRDD { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { + def writeToStream(elem: Any, dataOut: DataOutputStream) { + elem match { + case bytes: Array[Byte] => + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + case pair: (Array[Byte], Array[Byte]) => + dataOut.writeInt(pair._1.length) + dataOut.write(pair._1) + dataOut.writeInt(pair._2.length) + dataOut.write(pair._2) + case str: String => + dataOut.writeUTF(str) + case other => + throw new SparkException("Unexpected element type " + other.getClass) + } + } + + def writeToFile[T](items: java.util.Iterator[T], filename: String) { import scala.collection.JavaConverters._ - writeIteratorToPickleFile(items.asScala, filename) + writeToFile(items.asScala, filename) } - def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) { + def writeToFile[T](items: Iterator[T], filename: String) { val file = new DataOutputStream(new FileOutputStream(filename)) for (item <- items) { - writeAsPickle(item, file) + writeToStream(item, file) } file.close() } @@ -289,17 +241,6 @@ private[spark] object PythonRDD { } } -private object Pickle { - val PROTO: Byte = 0x80.toByte - val TWO: Byte = 0x02.toByte - val BINUNICODE: Byte = 'X' - val STOP: Byte = '.' - val TUPLE2: Byte = 0x86.toByte - val EMPTY_LIST: Byte = ']' - val MARK: Byte = '(' - val APPENDS: Byte = 'e' -} - private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/python/epydoc.conf b/python/epydoc.conf index 1d0d002d36623409f341688811d4d846ec6ee9fe..0b42e729f8dcc756c711584de2b2a4f071b480c5 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -32,6 +32,6 @@ target: docs/ private: no -exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers +exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test pyspark.rddsampler pyspark.daemon diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index da3d96689aa15dc14707e8f88c9170e51af3cede..2204e9c9ca7011f10ca60c900896db13868c35ff 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -90,9 +90,11 @@ import struct import SocketServer import threading from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import read_int, read_with_length, load_pickle +from pyspark.serializers import read_int, PickleSerializer +pickleSer = PickleSerializer() + # Holds accumulators registered on the current machine, keyed by ID. 
This is then used to send # the local accumulator updates back to the driver program at the end of a task. _accumulatorRegistry = {} @@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler): from pyspark.accumulators import _accumulatorRegistry num_updates = read_int(self.rfile) for _ in range(num_updates): - (aid, update) = load_pickle(read_with_length(self.rfile)) + (aid, update) = pickleSer._read_with_length(self.rfile) _accumulatorRegistry[aid] += update # Write a byte in acknowledgement self.wfile.write(struct.pack("!b", 1)) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a7ca8bc888c6759aff5784d26ad7df015d2fe2f4..cbd41e58c4a780392b2a6b8c58320535e416cd36 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway -from pyspark.serializers import dump_pickle, write_with_length, batched +from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD @@ -42,7 +42,7 @@ class SparkContext(object): _gateway = None _jvm = None - _writeIteratorToPickleFile = None + _writeToFile = None _takePartition = None _next_accum_id = 0 _active_spark_context = None @@ -51,7 +51,7 @@ class SparkContext(object): def __init__(self, master, jobName, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024): + environment=None, batchSize=1024, serializer=PickleSerializer()): """ Create a new SparkContext. @@ -67,6 +67,7 @@ class SparkContext(object): @param batchSize: The number of Python objects represented as a single Java object. Set 1 to disable batching or -1 to use an unlimited batch size. + @param serializer: The serializer for RDDs. 
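The constructor body below wraps the chosen serializer in a C{BatchedSerializer} unless C{batchSize == 1}. The effect of that wrapping can be sketched outside Spark with an in-memory stream (illustrative only; assumes the pyspark package from this patch is importable):

from io import BytesIO
from pyspark.serializers import PickleSerializer, BatchedSerializer

unbatched = PickleSerializer()
serializer = BatchedSerializer(unbatched, 2)   # what batchSize=2 produces

buf = BytesIO()
serializer.dump_stream(xrange(5), buf)         # written as [0, 1], [2, 3], [4]
buf.seek(0)
print list(serializer.load_stream(buf))        # [0, 1, 2, 3, 4]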
>>> from pyspark.context import SparkContext @@ -83,7 +84,13 @@ class SparkContext(object): self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J self.environment = environment or {} - self.batchSize = batchSize # -1 represents a unlimited batch size + self._batchSize = batchSize # -1 represents an unlimited batch size + self._unbatched_serializer = serializer + if batchSize == 1: + self.serializer = self._unbatched_serializer + else: + self.serializer = BatchedSerializer(self._unbatched_serializer, + batchSize) # Create the Java SparkContext through Py4J empty_string_array = self._gateway.new_array(self._jvm.String, 0) @@ -125,8 +132,8 @@ class SparkContext(object): if not SparkContext._gateway: SparkContext._gateway = launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm - SparkContext._writeIteratorToPickleFile = \ - SparkContext._jvm.PythonRDD.writeIteratorToPickleFile + SparkContext._writeToFile = \ + SparkContext._jvm.PythonRDD.writeToFile SparkContext._takePartition = \ SparkContext._jvm.PythonRDD.takePartition @@ -184,15 +191,17 @@ class SparkContext(object): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = min(len(c) // numSlices, self.batchSize) + batchSize = min(len(c) // numSlices, self._batchSize) if batchSize > 1: - c = batched(c, batchSize) - for x in c: - write_with_length(dump_pickle(x), tempFile) + serializer = BatchedSerializer(self._unbatched_serializer, + batchSize) + else: + serializer = self._unbatched_serializer + serializer.dump_stream(c, tempFile) tempFile.close() - readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile - jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) - return RDD(jrdd, self) + readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile + jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) + return RDD(jrdd, self, serializer) def textFile(self, name, minSplits=None): """ @@ -201,21 +210,39 @@ class SparkContext(object): RDD of Strings. """ minSplits = minSplits or min(self.defaultParallelism, 2) - jrdd = self._jsc.textFile(name, minSplits) - return RDD(jrdd, self) + return RDD(self._jsc.textFile(name, minSplits), self, + MUTF8Deserializer()) - def _checkpointFile(self, name): + def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) - return RDD(jrdd, self) + return RDD(jrdd, self, input_deserializer) def union(self, rdds): """ Build the union of a list of RDDs. + + This supports unions() of RDDs with different serialized formats, + although this forces them to be reserialized using the default + serializer: + + >>> path = os.path.join(tempdir, "union-text.txt") + >>> with open(path, "w") as testFile: + ... 
testFile.write("Hello") + >>> textFile = sc.textFile(path) + >>> textFile.collect() + [u'Hello'] + >>> parallelized = sc.parallelize(["World!"]) + >>> sorted(sc.union([textFile, parallelized]).collect()) + [u'Hello', 'World!'] """ + first_jrdd_deserializer = rdds[0]._jrdd_deserializer + if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds): + rdds = [x._reserialize() for x in rdds] first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self.gateway._gateway_client) - return RDD(self._jsc.union(first, rest), self) + rest = ListConverter().convert(rest, self._gateway._gateway_client) + return RDD(self._jsc.union(first, rest), self, + rdds[0]._jrdd_deserializer) def broadcast(self, value): """ @@ -223,7 +250,9 @@ class SparkContext(object): object for reading it in distributed functions. The variable will be sent to each cluster only once. """ - jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) + pickleSer = PickleSerializer() + pickled = pickleSer.dumps(value) + jbroadcast = self._jsc.broadcast(bytearray(pickled)) return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) @@ -235,7 +264,7 @@ class SparkContext(object): and floating-point numbers if you do not provide one. For other types, a custom AccumulatorParam can be used. """ - if accum_param == None: + if accum_param is None: if isinstance(value, int): accum_param = accumulators.INT_ACCUMULATOR_PARAM elif isinstance(value, float): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7019fb8beefc8f5c8c2060899e2c28e8ae10643e..957f3f89c0eb2dda5afbda8ca3c9a850727308f7 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -18,7 +18,7 @@ from base64 import standard_b64encode as b64enc import copy from collections import defaultdict -from itertools import chain, ifilter, imap, product +from itertools import chain, ifilter, imap import operator import os import sys @@ -27,9 +27,8 @@ from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile from threading import Thread -from pyspark import cloudpickle -from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ - read_from_pickle_file, pack_long +from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ + BatchedSerializer, CloudPickleSerializer, pack_long from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -48,12 +47,12 @@ class RDD(object): operated on in parallel. """ - def __init__(self, jrdd, ctx): + def __init__(self, jrdd, ctx, jrdd_deserializer): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False self.ctx = ctx - self._partitionFunc = None + self._jrdd_deserializer = jrdd_deserializer @property def context(self): @@ -247,7 +246,23 @@ class RDD(object): >>> rdd.union(rdd).collect() [1, 1, 2, 3, 1, 1, 2, 3] """ - return RDD(self._jrdd.union(other._jrdd), self.ctx) + if self._jrdd_deserializer == other._jrdd_deserializer: + rdd = RDD(self._jrdd.union(other._jrdd), self.ctx, + self._jrdd_deserializer) + return rdd + else: + # These RDDs contain data in different serialized formats, so we + # must normalize them to the default serializer. 
+ self_copy = self._reserialize() + other_copy = other._reserialize() + return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, + self.ctx.serializer) + + def _reserialize(self): + if self._jrdd_deserializer == self.ctx.serializer: + return self + else: + return self.map(lambda x: x, preservesPartitioning=True) def __add__(self, other): """ @@ -334,17 +349,9 @@ class RDD(object): [(1, 1), (1, 2), (2, 1), (2, 2)] """ # Due to batching, we can't use the Java cartesian method. - java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx) - def unpack_batches(pair): - (x, y) = pair - if type(x) == Batch or type(y) == Batch: - xs = x.items if type(x) == Batch else [x] - ys = y.items if type(y) == Batch else [y] - for pair in product(xs, ys): - yield pair - else: - yield pair - return java_cartesian.flatMap(unpack_batches) + deserializer = CartesianDeserializer(self._jrdd_deserializer, + other._jrdd_deserializer) + return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer) def groupBy(self, f, numPartitions=None): """ @@ -391,8 +398,8 @@ class RDD(object): """ Return a list that contains all of the elements in this RDD. """ - picklesInJava = self._jrdd.collect().iterator() - return list(self._collect_iterator_through_file(picklesInJava)) + bytesInJava = self._jrdd.collect().iterator() + return list(self._collect_iterator_through_file(bytesInJava)) def _collect_iterator_through_file(self, iterator): # Transferring lots of data through Py4J can be slow because @@ -400,10 +407,10 @@ class RDD(object): # file and read it back. tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) tempFile.close() - self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) + self.ctx._writeToFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: - for item in read_from_pickle_file(tempFile): + for item in self._jrdd_deserializer.load_stream(tempFile): yield item os.unlink(tempFile.name) @@ -571,7 +578,7 @@ class RDD(object): items = [] for partition in range(mapped._jrdd.splits().size()): iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition) - items.extend(self._collect_iterator_through_file(iterator)) + items.extend(mapped._collect_iterator_through_file(iterator)) if len(items) >= num: break return items[:num] @@ -735,6 +742,7 @@ class RDD(object): # Transferring O(n) objects to Java is too expensive. Instead, we'll # form the hash buckets in Python, transferring O(numPartitions) objects # to Java. Each object is a (splitNumber, [objects]) pair. 
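The add_shuffle_key generator below interleaves an 8-byte split id with one serialized bucket of (key, value) pairs, and keyed._bypass_serializer = True ships those raw byte strings through a NoOpSerializer so that Java's PairwiseRDD can re-pair them. A standalone sketch of that stream layout (illustrative, not the rdd.py code itself):

from collections import defaultdict
from pyspark.serializers import PickleSerializer, pack_long

def sketch_shuffle_stream(iterator, numPartitions, partitionFunc=hash):
    ser = PickleSerializer()
    buckets = defaultdict(list)
    for (k, v) in iterator:
        buckets[partitionFunc(k) % numPartitions].append((k, v))
    for (split, items) in buckets.iteritems():
        yield pack_long(split)    # which partition the next blob belongs to
        yield ser.dumps(items)    # the whole bucket as one pickled list

stream = list(sketch_shuffle_stream([("a", 1), ("b", 2)], 2))
assert len(stream) % 2 == 0       # alternating (split id, payload) records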
+ outputSerializer = self.ctx._unbatched_serializer def add_shuffle_key(split, iterator): buckets = defaultdict(list) @@ -743,14 +751,14 @@ class RDD(object): buckets[partitionFunc(k) % numPartitions].append((k, v)) for (split, items) in buckets.iteritems(): yield pack_long(split) - yield dump_pickle(Batch(items)) + yield outputSerializer.dumps(items) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, id(partitionFunc)) jrdd = pairRDD.partitionBy(partitioner).values() - rdd = RDD(jrdd, self.ctx) + rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) # This is required so that id(partitionFunc) remains unique, even if # partitionFunc is a lambda: rdd._partitionFunc = partitionFunc @@ -787,7 +795,8 @@ class RDD(object): numPartitions = self.ctx.defaultParallelism def combineLocally(iterator): combiners = {} - for (k, v) in iterator: + for x in iterator: + (k, v) = x if k not in combiners: combiners[k] = createCombiner(v) else: @@ -929,38 +938,39 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and prev._is_pipelinable(): + if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable(): + # This transformation is the first in its stage: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jrdd = prev._jrdd + self._prev_jrdd_deserializer = prev._jrdd_deserializer + else: prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) self.func = pipeline_func self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning - self._prev_jrdd = prev._prev_jrdd - else: - self.func = func - self.preservesPartitioning = preservesPartitioning - self._prev_jrdd = prev._jrdd + self._prev_jrdd = prev._prev_jrdd # maintain the pipeline + self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer self.is_cached = False self.is_checkpointed = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None + self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False @property def _jrdd(self): if self._jrdd_val: return self._jrdd_val - func = self.func - if not self._bypass_serializer and self.ctx.batchSize != 1: - oldfunc = self.func - batchSize = self.ctx.batchSize - def batched_func(split, iterator): - return batched(oldfunc(split, iterator), batchSize) - func = batched_func - cmds = [func, self._bypass_serializer] - pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) + if self._bypass_serializer: + serializer = NoOpSerializer() + else: + serializer = self.ctx.serializer + command = (self.func, self._prev_jrdd_deserializer, serializer) + pickled_command = CloudPickleSerializer().dumps(command) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) @@ -971,8 +981,9 @@ class PipelinedRDD(RDD): includes = ListConverter().convert(self.ctx._python_includes, self.ctx._gateway._gateway_client) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, self.ctx._javaAccumulator, class_manifest) + bytearray(pickled_command), env, includes, self.preservesPartitioning, + self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator, + class_manifest) 
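The pickled_command built above replaces the old base64 text lines: the driver cloudpickles a (func, input deserializer, output serializer) tuple, PythonRDD.scala writes it as one length-prefixed byte block, and worker.py reads it back with PickleSerializer._read_with_length. A round trip of that handshake in plain Python (illustrative sketch, assuming pyspark is importable):

from io import BytesIO
from pyspark.serializers import (CloudPickleSerializer, PickleSerializer,
                                 BatchedSerializer, write_int)

command = (lambda split, it: [x * 2 for x in it],   # the pipelined function
           BatchedSerializer(PickleSerializer()),   # how to read the input
           BatchedSerializer(PickleSerializer()))   # how to write the output
pickled_command = CloudPickleSerializer().dumps(command)

wire = BytesIO()
write_int(len(pickled_command), wire)   # dataOut.writeInt(command.length)
wire.write(pickled_command)             # dataOut.write(command)

wire.seek(0)                            # what worker.py does on the other end
func, deserializer, serializer = PickleSerializer()._read_with_length(wire)
print func(0, iter([1, 2, 3]))          # [2, 4, 6]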
self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 54fed1c9c70f66e503abb5c523d6327bb9bae8b4..811fa6f018b23f3c9883bd2a770f03c044786850 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -15,45 +15,269 @@ # limitations under the License. # -import struct +""" +PySpark supports custom serializers for transferring data; this can improve +performance. + +By default, PySpark uses L{PickleSerializer} to serialize objects using Python's +C{cPickle} serializer, which can serialize nearly any Python object. +Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be +faster. + +The serializer is chosen when creating L{SparkContext}: + +>>> from pyspark.context import SparkContext +>>> from pyspark.serializers import MarshalSerializer +>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer()) +>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) +[0, 2, 4, 6, 8, 10, 12, 14, 16, 18] +>>> sc.stop() + +By default, PySpark serialize objects in batches; the batch size can be +controlled through SparkContext's C{batchSize} parameter +(the default size is 1024 objects): + +>>> sc = SparkContext('local', 'test', batchSize=2) +>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) + +Behind the scenes, this creates a JavaRDD with four partitions, each of +which contains two batches of two objects: + +>>> rdd.glom().collect() +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] +>>> rdd._jrdd.count() +8L +>>> sc.stop() + +A batch size of -1 uses an unlimited batch size, and a size of 1 disables +batching: + +>>> sc = SparkContext('local', 'test', batchSize=1) +>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) +>>> rdd.glom().collect() +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] +>>> rdd._jrdd.count() +16L +""" + import cPickle +from itertools import chain, izip, product +import marshal +import struct +from pyspark import cloudpickle + + +__all__ = ["PickleSerializer", "MarshalSerializer"] + + +class SpecialLengths(object): + END_OF_DATA_SECTION = -1 + PYTHON_EXCEPTION_THROWN = -2 + TIMING_DATA = -3 + + +class Serializer(object): + + def dump_stream(self, iterator, stream): + """ + Serialize an iterator of objects to the output stream. + """ + raise NotImplementedError + + def load_stream(self, stream): + """ + Return an iterator of deserialized objects from the input stream. + """ + raise NotImplementedError + + + def _load_stream_without_unbatching(self, stream): + return self.load_stream(stream) + + # Note: our notion of "equality" is that output generated by + # equal serializers can be deserialized using the same serializer. + + # This default implementation handles the simple cases; + # subclasses should override __eq__ as appropriate. + + def __eq__(self, other): + return isinstance(other, self.__class__) + + def __ne__(self, other): + return not self.__eq__(other) + + +class FramedSerializer(Serializer): + """ + Serializer that writes objects as a stream of (length, data) pairs, + where C{length} is a 32-bit integer and data is C{length} bytes. 
+ """ + + def dump_stream(self, iterator, stream): + for obj in iterator: + self._write_with_length(obj, stream) + + def load_stream(self, stream): + while True: + try: + yield self._read_with_length(stream) + except EOFError: + return + + def _write_with_length(self, obj, stream): + serialized = self.dumps(obj) + write_int(len(serialized), stream) + stream.write(serialized) + + def _read_with_length(self, stream): + length = read_int(stream) + obj = stream.read(length) + if obj == "": + raise EOFError + return self.loads(obj) + + def dumps(self, obj): + """ + Serialize an object into a byte array. + When batching is used, this will be called with an array of objects. + """ + raise NotImplementedError + + def loads(self, obj): + """ + Deserialize an object from a byte array. + """ + raise NotImplementedError + + +class BatchedSerializer(Serializer): + """ + Serializes a stream of objects in batches by calling its wrapped + Serializer with streams of objects. + """ + + UNLIMITED_BATCH_SIZE = -1 + + def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): + self.serializer = serializer + self.batchSize = batchSize + + def _batched(self, iterator): + if self.batchSize == self.UNLIMITED_BATCH_SIZE: + yield list(iterator) + else: + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == self.batchSize: + yield items + items = [] + count = 0 + if items: + yield items + + def dump_stream(self, iterator, stream): + self.serializer.dump_stream(self._batched(iterator), stream) + + def load_stream(self, stream): + return chain.from_iterable(self._load_stream_without_unbatching(stream)) + + def _load_stream_without_unbatching(self, stream): + return self.serializer.load_stream(stream) + + def __eq__(self, other): + return isinstance(other, BatchedSerializer) and \ + other.serializer == self.serializer + + def __str__(self): + return "BatchedSerializer<%s>" % str(self.serializer) -class Batch(object): +class CartesianDeserializer(FramedSerializer): """ - Used to store multiple RDD entries as a single Java object. + Deserializes the JavaRDD cartesian() of two PythonRDDs. + """ + + def __init__(self, key_ser, val_ser): + self.key_ser = key_ser + self.val_ser = val_ser + + def load_stream(self, stream): + key_stream = self.key_ser._load_stream_without_unbatching(stream) + val_stream = self.val_ser._load_stream_without_unbatching(stream) + key_is_batched = isinstance(self.key_ser, BatchedSerializer) + val_is_batched = isinstance(self.val_ser, BatchedSerializer) + for (keys, vals) in izip(key_stream, val_stream): + keys = keys if key_is_batched else [keys] + vals = vals if val_is_batched else [vals] + for pair in product(keys, vals): + yield pair + + def __eq__(self, other): + return isinstance(other, CartesianDeserializer) and \ + self.key_ser == other.key_ser and self.val_ser == other.val_ser + + def __str__(self): + return "CartesianDeserializer<%s, %s>" % \ + (str(self.key_ser), str(self.val_ser)) - This relieves us from having to explicitly track whether an RDD - is stored as batches of objects and avoids problems when processing - the union() of batched and unbatched RDDs (e.g. the union() of textFile() - with another RDD). 
+ +class NoOpSerializer(FramedSerializer): + + def loads(self, obj): return obj + def dumps(self, obj): return obj + + +class PickleSerializer(FramedSerializer): """ - def __init__(self, items): - self.items = items + Serializes objects using Python's cPickle serializer: + http://docs.python.org/2/library/pickle.html -def batched(iterator, batchSize): - if batchSize == -1: # unlimited batch size - yield Batch(list(iterator)) - else: - items = [] - count = 0 - for item in iterator: - items.append(item) - count += 1 - if count == batchSize: - yield Batch(items) - items = [] - count = 0 - if items: - yield Batch(items) + This serializer supports nearly any Python object, but may + not be as fast as more specialized serializers. + """ + def dumps(self, obj): return cPickle.dumps(obj, 2) + loads = cPickle.loads -def dump_pickle(obj): - return cPickle.dumps(obj, 2) +class CloudPickleSerializer(PickleSerializer): + def dumps(self, obj): return cloudpickle.dumps(obj, 2) -load_pickle = cPickle.loads + +class MarshalSerializer(FramedSerializer): + """ + Serializes objects using Python's Marshal serializer: + + http://docs.python.org/2/library/marshal.html + + This serializer is faster than PickleSerializer but supports fewer datatypes. + """ + + dumps = marshal.dumps + loads = marshal.loads + + +class MUTF8Deserializer(Serializer): + """ + Deserializes streams written by Java's DataOutputStream.writeUTF(). + """ + + def loads(self, stream): + length = struct.unpack('>H', stream.read(2))[0] + return stream.read(length).decode('utf8') + + def load_stream(self, stream): + while True: + try: + yield self.loads(stream) + except struct.error: + return + except EOFError: + return def read_long(stream): @@ -84,25 +308,4 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) - stream.write(obj) - - -def read_with_length(stream): - length = read_int(stream) - obj = stream.read(length) - if obj == "": - raise EOFError - return obj - - -def read_from_pickle_file(stream): - try: - while True: - obj = load_pickle(read_with_length(stream)) - if type(obj) == Batch: # We don't care about inheritance - for item in obj.items: - yield item - else: - yield obj - except EOFError: - return + stream.write(obj) \ No newline at end of file diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 29d6a128f6a9b83dc742c676ef010a90f54ab73e..621e1cb58c3df10afa1f64a8f7b9f988dd71b0cb 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -86,7 +86,8 @@ class TestCheckpoint(PySparkTestCase): time.sleep(1) # 1 second self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) - recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) + recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), + flatMappedRDD._jrdd_deserializer) self.assertEquals([1, 2, 3, 4], recovered.collect()) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d63c2aaef772de62eef3bf913ad4a4859cf30512..f2b3f3c1421d12d48637ca47b3ac39ed61bbcfac 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,23 +23,22 @@ import sys import time import socket import traceback -from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. 
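Strings such as the workdir and the Python include names are now written with Java's DataOutputStream.writeUTF and read by the new MUTF8Deserializer: a 2-byte big-endian length followed by (modified) UTF-8 bytes. A hand-built frame shows the format; the helper and sample values below are illustrative, and plain UTF-8 is sufficient for ASCII data:

import struct
from io import BytesIO
from pyspark.serializers import MUTF8Deserializer

def java_write_utf(s, stream):
    # Emulates DataOutputStream.writeUTF for simple strings.
    data = s.encode('utf8')
    stream.write(struct.pack('>H', len(data)))   # unsigned 16-bit length
    stream.write(data)

buf = BytesIO()
java_write_utf(u"/tmp/spark-files", buf)
java_write_utf(u"deps.egg", buf)
buf.seek(0)
print list(MUTF8Deserializer().load_stream(buf))
# [u'/tmp/spark-files', u'deps.egg']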
from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles -from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file +from pyspark.serializers import write_with_length, write_int, read_long, \ + write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer -def load_obj(infile): - return load_pickle(standard_b64decode(infile.readline().strip())) +pickleSer = PickleSerializer() +mutf8_deserializer = MUTF8Deserializer() def report_times(outfile, boot, init, finish): - write_int(-3, outfile) + write_int(SpecialLengths.TIMING_DATA, outfile) write_long(1000 * boot, outfile) write_long(1000 * init, outfile) write_long(1000 * finish, outfile) @@ -52,7 +51,7 @@ def main(infile, outfile): return # fetch name of workdir - spark_files_dir = load_pickle(read_with_length(infile)) + spark_files_dir = mutf8_deserializer.loads(infile) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True @@ -60,38 +59,33 @@ def main(infile, outfile): num_broadcast_variables = read_int(infile) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) + value = pickleSer._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): - sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile)))) + filename = mutf8_deserializer.loads(infile) + sys.path.append(os.path.join(spark_files_dir, filename)) - # now load function - func = load_obj(infile) - bypassSerializer = load_obj(infile) - if bypassSerializer: - dumps = lambda x: x - else: - dumps = dump_pickle + command = pickleSer._read_with_length(infile) + (func, deserializer, serializer) = command init_time = time.time() - iterator = read_from_pickle_file(infile) try: - for obj in func(split_index, iterator): - write_with_length(dumps(obj), outfile) + iterator = deserializer.load_stream(infile) + serializer.dump_stream(func(split_index, iterator), outfile) except Exception as e: - write_int(-2, outfile) + write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) write_with_length(traceback.format_exc(), outfile) sys.exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output - write_int(-1, outfile) - for aid, accum in _accumulatorRegistry.items(): - write_with_length(dump_pickle((aid, accum._value)), outfile) - write_int(-1, outfile) + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + write_int(len(_accumulatorRegistry), outfile) + for (aid, accum) in _accumulatorRegistry.items(): + pickleSer._write_with_length((aid, accum._value), outfile) if __name__ == '__main__': diff --git a/python/run-tests b/python/run-tests index cbc554ea9db0d2cdd18323a408879d98f80810cb..d4dad672d284299a0a1fc116b24f7833eefe1b7b 100755 --- a/python/run-tests +++ b/python/run-tests @@ -37,6 +37,7 @@ run_test "pyspark/rdd.py" run_test "pyspark/context.py" run_test "-m doctest pyspark/broadcast.py" run_test "-m doctest pyspark/accumulators.py" 
+run_test "-m doctest pyspark/serializers.py" run_test "pyspark/tests.py" if [[ $FAILED != 0 ]]; then