diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 9f88340d0377828aad29a4cfc30b2b7b205bc1ae..1374f74968c9eacb57dfd92dbb62bbd42fca5be2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1741,6 +1741,53 @@ class RDD(object): other._jrdd_deserializer) return RDD(pairRDD, self.ctx, deserializer) + def zipWithIndex(self): + """ + Zips this RDD with its element indices. + + The ordering is first based on the partition index and then the + ordering of items within each partition. So the first item in + the first partition gets index 0, and the last item in the last + partition receives the largest index. + + This method needs to trigger a spark job when this RDD contains + more than one partitions. + + >>> sc.parallelize(["a", "b", "c", "d"], 3).zipWithIndex().collect() + [('a', 0), ('b', 1), ('c', 2), ('d', 3)] + """ + starts = [0] + if self.getNumPartitions() > 1: + nums = self.mapPartitions(lambda it: [sum(1 for i in it)]).collect() + for i in range(len(nums) - 1): + starts.append(starts[-1] + nums[i]) + + def func(k, it): + for i, v in enumerate(it, starts[k]): + yield v, i + + return self.mapPartitionsWithIndex(func) + + def zipWithUniqueId(self): + """ + Zips this RDD with generated unique Long ids. + + Items in the kth partition will get ids k, n+k, 2*n+k, ..., where + n is the number of partitions. So there may exist gaps, but this + method won't trigger a spark job, which is different from + L{zipWithIndex} + + >>> sc.parallelize(["a", "b", "c", "d", "e"], 3).zipWithUniqueId().collect() + [('a', 0), ('b', 1), ('c', 4), ('d', 2), ('e', 5)] + """ + n = self.getNumPartitions() + + def func(k, it): + for i, v in enumerate(it): + yield v, i * n + k + + return self.mapPartitionsWithIndex(func) + def name(self): """ Return the name of this RDD.