From 7d8bb4df3a5f8078cd4e86cef5e3b0b728afd2bc Mon Sep 17 00:00:00 2001
From: Stephen Haberman <stephen@exigencecorp.com>
Date: Thu, 14 Mar 2013 14:44:15 -0500
Subject: [PATCH] Allow subtractByKey's other argument to have a different
 value type.

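subtractByKey only inspects the keys of `other`, so there is no reason
to require its value type to match ours. Add a `W` type parameter for
`other`'s value type to the three subtractByKey overloads and to
SubtractedRDD. Because rdd1 and rdd2 no longer share a value type, the
ShuffleDependency in getDependencies is created through a cast to
RDD[(K, Any)]; this is safe because the dependency only partitions by
key and carries the values through opaquely.
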
---
 core/src/main/scala/spark/PairRDDFunctions.scala  | 8 ++++----
 core/src/main/scala/spark/rdd/SubtractedRDD.scala | 6 +++---
 core/src/test/scala/spark/ShuffleSuite.scala      | 2 +-
 3 files changed, 8 insertions(+), 8 deletions(-)

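Notes (not part of the commit message): a minimal sketch of what this
change permits, assuming the usual pair-RDD implicits are in scope; the
names below are illustrative and do not appear in the patch:

    val users  = sc.parallelize(Seq((1, "alice"), (2, "bob")))  // RDD[(Int, String)]
    val banned = sc.parallelize(Seq((2, 9L)))                   // RDD[(Int, Long)]
    val kept   = users.subtractByKey(banned)                    // still RDD[(Int, String)]
    // kept.collect() === Array((1, "alice"))

Before this patch, `banned` would have had to be an RDD[(Int, String)]
to compile, even though its values are never read.
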
diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 47b9c6962f..3d1b1ca268 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -446,16 +446,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
    * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
    * RDD will be <= us.
    */
-  def subtractByKey(other: RDD[(K, V)]): RDD[(K, V)] =
+  def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] =
     subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))
 
   /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
-  def subtractByKey(other: RDD[(K, V)], numPartitions: Int): RDD[(K, V)] =
+  def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
     subtractByKey(other, new HashPartitioner(numPartitions))
 
   /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
-  def subtractByKey(other: RDD[(K, V)], p: Partitioner): RDD[(K, V)] =
-    new SubtractedRDD[K, V](self, other, p)
+  def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
+    new SubtractedRDD[K, V, W](self, other, p)
 
   /**
    * Return the list of values in the RDD for key `key`. This operation is done efficiently if the
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 90488f13cc..2f8ff9bb34 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -28,9 +28,9 @@ import spark.OneToOneDependency
  * you can use `rdd1`'s partitioner/partition size and not worry about running
  * out of memory because of the size of `rdd2`.
  */
-private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest](
+private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
     @transient var rdd1: RDD[(K, V)],
-    @transient var rdd2: RDD[(K, V)],
+    @transient var rdd2: RDD[(K, W)],
     part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
 
   override def getDependencies: Seq[Dependency[_]] = {
@@ -40,7 +40,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest](
 	    new OneToOneDependency(rdd)
 	  } else {
 	    logInfo("Adding shuffle dependency with " + rdd)
-	    new ShuffleDependency(rdd, part)
+	    new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part)
 	  }
     }
   }
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 731c45cca2..2b2a90defa 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -283,7 +283,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
   test("subtractByKey") {
     sc = new SparkContext("local", "test")
     val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
-    val b = sc.parallelize(Array((2, "bb"), (3, "cc"), (4, "dd")), 4)
+    val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
     val c = a.subtractByKey(b)
     assert(c.collect().toSet === Set((1, "a"), (1, "a")))
     assert(c.partitions.size === a.partitions.size)
-- 
GitLab