diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 1bd1741a713943499be7d3c700e91229f836998a..47b9c6962f951033bf5e3cf696b2179cd3163c91 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -440,6 +440,23 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     cogroup(other1, other2, defaultPartitioner(self, other1, other2))
   }
 
+  /**
+   * Return an RDD with the pairs from `this` whose keys are not in `other`.
+   *
+   * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
+   * RDD will be <= us.
+   */
+  def subtractByKey(other: RDD[(K, V)]): RDD[(K, V)] =
+    subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))
+
+  /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+  def subtractByKey(other: RDD[(K, V)], numPartitions: Int): RDD[(K, V)] =
+    subtractByKey(other, new HashPartitioner(numPartitions))
+
+  /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+  def subtractByKey(other: RDD[(K, V)], p: Partitioner): RDD[(K, V)] =
+    new SubtractedRDD[K, V](self, other, p)
+
   /**
    * Return the list of values in the RDD for key `key`. This operation is done efficiently if the
    * RDD has a known partitioner by only searching the partition that the key maps to.
@@ -639,8 +656,6 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
       }
     }, true)
   }
-
-  // def subtractByKey(other: RDD[K]): RDD[(K,V)] = subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
 }
 
 private[spark]
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 3451136fd40eb3dd439b66da508c45fd4d42883f..9bd8a0f98daa8a47c2b625f68eb1af4b02c0629a 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -408,24 +408,8 @@ abstract class RDD[T: ClassManifest](
    * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
    * RDD will be <= us.
    */
-  def subtract(other: RDD[T]): RDD[T] = {
-    // If we do have a partitioner, our T is really (K, V), and we'll need to
-    // unwrap the (T, null) that subtract does to get back to the K
-    val rdd = subtract(other, partitioner match {
-      case None => new HashPartitioner(partitions.size)
-      case Some(p) => new Partitioner() {
-        override def numPartitions = p.numPartitions
-        override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1)
-      }
-    })
-    // Hacky, but if we did have a partitioner, we can keep using it
-    new RDD[T](rdd) {
-      override def getPartitions = rdd.partitions
-      override def getDependencies = rdd.dependencies
-      override def compute(split: Partition, context: TaskContext) = rdd.compute(split, context)
-      override val partitioner = RDD.this.partitioner
-    }
-  }
+  def subtract(other: RDD[T]): RDD[T] =
+    subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
 
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
@@ -437,7 +421,21 @@ abstract class RDD[T: ClassManifest](
    * Return an RDD with the elements from `this` that are not in `other`.
    */
   def subtract(other: RDD[T], p: Partitioner): RDD[T] = {
-    new SubtractedRDD[T, Any](this.map((_, null)), other.map((_, null)), p).keys
+    if (partitioner == Some(p)) {
+      // Our partitioner knows how to handle T (which, since we have a partitioner, is
+      // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples
+      val p2 = new Partitioner() {
+        override def numPartitions = p.numPartitions
+        override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1)
+      }
+      // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies
+      // anyway, and when calling .keys, will not have a partitioner set, even though
+      // the SubtractedRDD will, thanks to p2's de-tupled partitioning, already be
+      // partitioned by the right/real keys (e.g. p).
+      this.map(x => (x, null)).subtractByKey(other.map((_, null)), p2).keys
+    } else {
+      this.map(x => (x, null)).subtractByKey(other.map((_, null)), p).keys
+    }
   }
 
   /**
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 1bc84f7e1ef82a2c1bb8a5b4bb50911bdea9b948..90488f13cce2c9fd4e169f12d7a2322b97a1d9e3 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -87,7 +87,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest](
     // the first dep is rdd1; add all values to the map
     integrate(partition.deps(0), t => getSeq(t._1) += t._2)
     // the second dep is rdd2; remove all of its keys
-    integrate(partition.deps(1), t => map.remove(t._1) )
+    integrate(partition.deps(1), t => map.remove(t._1))
     map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten
   }
 
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 8411291b2caa31e86f58bbf17f39fbf68020a669..731c45cca2fe6875474ea55a340cbe245403f2ac 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -272,13 +272,39 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     }
     // partitionBy so we have a narrow dependency
     val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
-    println(sc.runJob(a, (i: Iterator[(Int, String)]) => i.toList).toList)
     // more partitions/no partitioner so a shuffle dependency
     val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
     val c = a.subtract(b)
     assert(c.collect().toSet === Set((1, "a"), (3, "c")))
+    // Ideally we could keep the original partitioner...
+    assert(c.partitioner === None)
+  }
+
+  test("subtractByKey") {
+    sc = new SparkContext("local", "test")
+    val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
+    val b = sc.parallelize(Array((2, "bb"), (3, "cc"), (4, "dd")), 4)
+    val c = a.subtractByKey(b)
+    assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+    assert(c.partitions.size === a.partitions.size)
+  }
+
+  test("subtractByKey with narrow dependency") {
+    sc = new SparkContext("local", "test")
+    // use a deterministic partitioner
+    val p = new Partitioner() {
+      def numPartitions = 5
+      def getPartition(key: Any) = key.asInstanceOf[Int]
+    }
+    // partitionBy so we have a narrow dependency
+    val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+    // more partitions/no partitioner so a shuffle dependency
+    val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+    val c = a.subtractByKey(b)
+    assert(c.collect().toSet === Set((1, "a"), (1, "a")))
     assert(c.partitioner.get === p)
   }
+
 }
 
 object ShuffleSuite {