diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 9ae9305d4f02ea9a0461bf4be416f9d09ac78a8e..211918f5a05ec06e871ff2d8832e8ab008644444 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -271,6 +271,20 @@ class SparkContext(object): jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self, serializer) + def pickleFile(self, name, minPartitions=None): + """ + Load an RDD previously saved using L{RDD.saveAsPickleFile} method. + + >>> tmpFile = NamedTemporaryFile(delete=True) + >>> tmpFile.close() + >>> sc.parallelize(range(10)).saveAsPickleFile(tmpFile.name, 5) + >>> sorted(sc.pickleFile(tmpFile.name, 3).collect()) + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + """ + minPartitions = minPartitions or self.defaultMinPartitions + return RDD(self._jsc.objectFile(name, minPartitions), self, + BatchedSerializer(PickleSerializer())) + def textFile(self, name, minPartitions=None): """ Read a text file from HDFS, a local file system (available on all diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1b3c460dd621e79c492874bc937afbcbd779b8bf..ca0a95578fd2806119dd7e2bae2f18eeb6c497a7 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -33,7 +33,8 @@ import heapq from random import Random from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ - BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long + BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ + PickleSerializer, pack_long from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -427,11 +428,14 @@ class RDD(object): .filter(lambda x: (len(x[1][0]) != 0) and (len(x[1][1]) != 0)) \ .keys() - def _reserialize(self): - if self._jrdd_deserializer == self.ctx.serializer: + def _reserialize(self, serializer=None): + serializer = serializer or self.ctx.serializer + if self._jrdd_deserializer == serializer: return self else: - return self.map(lambda x: x, preservesPartitioning=True) + converted = self.map(lambda x: x, preservesPartitioning=True) + converted._jrdd_deserializer = serializer + return converted def __add__(self, other): """ @@ -897,6 +901,20 @@ class RDD(object): """ return self.take(1)[0] + def saveAsPickleFile(self, path, batchSize=10): + """ + Save this RDD as a SequenceFile of serialized objects. The serializer used is + L{pyspark.serializers.PickleSerializer}, default batch size is 10. + + >>> tmpFile = NamedTemporaryFile(delete=True) + >>> tmpFile.close() + >>> sc.parallelize([1, 2, 'spark', 'rdd']).saveAsPickleFile(tmpFile.name, 3) + >>> sorted(sc.pickleFile(tmpFile.name, 5).collect()) + [1, 2, 'rdd', 'spark'] + """ + self._reserialize(BatchedSerializer(PickleSerializer(), + batchSize))._jrdd.saveAsObjectFile(path) + def saveAsTextFile(self, path): """ Save this RDD as a text file, using string representations of elements. @@ -1421,10 +1439,9 @@ class PipelinedRDD(RDD): if self._jrdd_val: return self._jrdd_val if self._bypass_serializer: - serializer = NoOpSerializer() - else: - serializer = self.ctx.serializer - command = (self.func, self._prev_jrdd_deserializer, serializer) + self._jrdd_deserializer = NoOpSerializer() + command = (self.func, self._prev_jrdd_deserializer, + self._jrdd_deserializer) pickled_command = CloudPickleSerializer().dumps(command) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],