Skip to content
Snippets Groups Projects
Commit 7d8bb4df authored by Stephen Haberman
Browse files

Allow subtractByKey's other argument to have a different value type.

parent 4632c45a
No related branches found
No related tags found
No related merge requests found
...@@ -446,16 +446,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( ...@@ -446,16 +446,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us. * RDD will be <= us.
*/ */
def subtractByKey(other: RDD[(K, V)]): RDD[(K, V)] = def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] =
subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))
/** Return an RDD with the pairs from `this` whose keys are not in `other`. */ /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
def subtractByKey(other: RDD[(K, V)], numPartitions: Int): RDD[(K, V)] = def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
subtractByKey(other, new HashPartitioner(numPartitions)) subtractByKey(other, new HashPartitioner(numPartitions))
/** Return an RDD with the pairs from `this` whose keys are not in `other`. */ /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
def subtractByKey(other: RDD[(K, V)], p: Partitioner): RDD[(K, V)] = def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
new SubtractedRDD[K, V](self, other, p) new SubtractedRDD[K, V, W](self, other, p)
/** /**
* Return the list of values in the RDD for key `key`. This operation is done efficiently if the * Return the list of values in the RDD for key `key`. This operation is done efficiently if the
......
...@@ -28,9 +28,9 @@ import spark.OneToOneDependency ...@@ -28,9 +28,9 @@ import spark.OneToOneDependency
* you can use `rdd1`'s partitioner/partition size and not worry about running * you can use `rdd1`'s partitioner/partition size and not worry about running
* out of memory because of the size of `rdd2`. * out of memory because of the size of `rdd2`.
*/ */
private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest]( private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
@transient var rdd1: RDD[(K, V)], @transient var rdd1: RDD[(K, V)],
@transient var rdd2: RDD[(K, V)], @transient var rdd2: RDD[(K, W)],
part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) { part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
override def getDependencies: Seq[Dependency[_]] = { override def getDependencies: Seq[Dependency[_]] = {
...@@ -40,7 +40,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest]( ...@@ -40,7 +40,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest](
new OneToOneDependency(rdd) new OneToOneDependency(rdd)
} else { } else {
logInfo("Adding shuffle dependency with " + rdd) logInfo("Adding shuffle dependency with " + rdd)
new ShuffleDependency(rdd, part) new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part)
} }
} }
} }
......
...@@ -283,7 +283,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { ...@@ -283,7 +283,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
test("subtractByKey") { test("subtractByKey") {
sc = new SparkContext("local", "test") sc = new SparkContext("local", "test")
val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2) val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
val b = sc.parallelize(Array((2, "bb"), (3, "cc"), (4, "dd")), 4) val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
val c = a.subtractByKey(b) val c = a.subtractByKey(b)
assert(c.collect().toSet === Set((1, "a"), (1, "a"))) assert(c.collect().toSet === Set((1, "a"), (1, "a")))
assert(c.partitions.size === a.partitions.size) assert(c.partitions.size === a.partitions.size)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment