From 7d8bb4df3a5f8078cd4e86cef5e3b0b728afd2bc Mon Sep 17 00:00:00 2001
From: Stephen Haberman <stephen@exigencecorp.com>
Date: Thu, 14 Mar 2013 14:44:15 -0500
Subject: [PATCH] Allow subtractByKey's other argument to have a different
 value type.

---
 core/src/main/scala/spark/PairRDDFunctions.scala  | 8 ++++----
 core/src/main/scala/spark/rdd/SubtractedRDD.scala | 6 +++---
 core/src/test/scala/spark/ShuffleSuite.scala      | 2 +-
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 47b9c6962f..3d1b1ca268 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -446,16 +446,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
    * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
    * RDD will be <= us.
    */
-  def subtractByKey(other: RDD[(K, V)]): RDD[(K, V)] =
+  def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] =
     subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))
 
   /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
-  def subtractByKey(other: RDD[(K, V)], numPartitions: Int): RDD[(K, V)] =
+  def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
     subtractByKey(other, new HashPartitioner(numPartitions))
 
   /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
-  def subtractByKey(other: RDD[(K, V)], p: Partitioner): RDD[(K, V)] =
-    new SubtractedRDD[K, V](self, other, p)
+  def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
+    new SubtractedRDD[K, V, W](self, other, p)
 
   /**
    * Return the list of values in the RDD for key `key`. This operation is done efficiently if the
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 90488f13cc..2f8ff9bb34 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -28,9 +28,9 @@ import spark.OneToOneDependency
  * you can use `rdd1`'s partitioner/partition size and not worry about running
  * out of memory because of the size of `rdd2`.
  */
-private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest](
+private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
     @transient var rdd1: RDD[(K, V)],
-    @transient var rdd2: RDD[(K, V)],
+    @transient var rdd2: RDD[(K, W)],
     part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
 
   override def getDependencies: Seq[Dependency[_]] = {
@@ -40,7 +40,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest](
         new OneToOneDependency(rdd)
       } else {
         logInfo("Adding shuffle dependency with " + rdd)
-        new ShuffleDependency(rdd, part)
+        new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part)
       }
     }
   }
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 731c45cca2..2b2a90defa 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -283,7 +283,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
   test("subtractByKey") {
     sc = new SparkContext("local", "test")
     val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
-    val b = sc.parallelize(Array((2, "bb"), (3, "cc"), (4, "dd")), 4)
+    val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
     val c = a.subtractByKey(b)
     assert(c.collect().toSet === Set((1, "a"), (1, "a")))
     assert(c.partitions.size === a.partitions.size)
--
GitLab
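A minimal usage sketch of what this patch enables, assuming the pre-1.0 `spark` package API shown above (Spark 0.7-era, ClassManifest-based). The object and variable names (SubtractByKeyExample, events, tombstones) are illustrative only, not part of the patch:

    import spark.SparkContext
    import spark.SparkContext._ // implicit conversion to PairRDDFunctions

    // Sketch: after this patch, the `other` RDD may carry a value type
    // (Int here) different from ours (String); only keys are compared.
    object SubtractByKeyExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "subtractByKey-example")
        val events = sc.parallelize(Seq((1, "a"), (2, "b"), (3, "c")))
        val tombstones = sc.parallelize(Seq((2, 20), (3, 30)))
        // Keeps only the pairs whose keys do not appear in `tombstones`.
        val kept = events.subtractByKey(tombstones)
        assert(kept.collect().toSet == Set((1, "a")))
        sc.stop()
      }
    }

The design follows from the change itself: the new `W: ClassManifest` context bound mirrors the existing `K`/`V` bounds, and the `asInstanceOf[RDD[(K, Any)]]` cast in SubtractedRDD lets one ShuffleDependency cover both value types, since only the keys matter for the subtraction.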