diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index e7408e4352abfc2e53e20f91039e186ff72c139f..3d1b1ca268c8bdf251a276f6b46feb478f56645d 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -440,6 +440,23 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     cogroup(other1, other2, defaultPartitioner(self, other1, other2))
   }
 
+  /**
+   * Return an RDD with the pairs from `this` whose keys are not in `other`.
+   * 
+   * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
+   * RDD will be <= us.
+   */
+  def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] =
+    subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))
+
+  /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+  def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
+    subtractByKey(other, new HashPartitioner(numPartitions))
+
+  /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+  def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
+    new SubtractedRDD[K, V, W](self, other, p)
+
   /**
    * Return the list of values in the RDD for key `key`. This operation is done efficiently if the
    * RDD has a known partitioner by only searching the partition that the key maps to.
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 584efa8adf5f3877c914cf54b0fd22eefc325963..9bd8a0f98daa8a47c2b625f68eb1af4b02c0629a 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -420,7 +420,23 @@ abstract class RDD[T: ClassManifest](
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
    */
-  def subtract(other: RDD[T], p: Partitioner): RDD[T] = new SubtractedRDD[T](this, other, p)
+  def subtract(other: RDD[T], p: Partitioner): RDD[T] = {
+    if (partitioner == Some(p)) {
+      // Our partitioner knows how to handle T (which, since we have a partitioner, is
+      // really (K, V)) so make a new Partitioner that will de-tuple our fake tuples
+      val p2 = new Partitioner() {
+        override def numPartitions = p.numPartitions
+        override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1)
+      }
+      // Unfortunately, since we're making a new p2, we'll get ShuffleDependencies
+      // anyway, and when calling .keys, will not have a partitioner set, even though
+      // the SubtractedRDD will, thanks to p2's de-tupled partitioning, already be
+      // partitioned by the right/real keys (e.g. p).
+      this.map(x => (x, null)).subtractByKey(other.map((_, null)), p2).keys
+    } else {
+      this.map(x => (x, null)).subtractByKey(other.map((_, null)), p).keys
+    }
+  }
 
   /**
    * Reduces the elements of this RDD using the specified commutative and associative binary operator.
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 43ec90cac5a95111534a7e15d9638c438b71dde5..0a025610626779209c575caadbcd2067fb34a232 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -1,7 +1,8 @@
 package spark.rdd
 
-import java.util.{HashSet => JHashSet}
+import java.util.{HashMap => JHashMap}
 import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
 import spark.RDD
 import spark.Partitioner
 import spark.Dependency
@@ -27,10 +28,10 @@ import spark.OneToOneDependency
  * you can use `rdd1`'s partitioner/partition size and not worry about running
  * out of memory because of the size of `rdd2`.
  */
-private[spark] class SubtractedRDD[T: ClassManifest](
-    @transient var rdd1: RDD[T],
-    @transient var rdd2: RDD[T],
-    part: Partitioner) extends RDD[T](rdd1.context, Nil) {
+private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
+    @transient var rdd1: RDD[(K, V)],
+    @transient var rdd2: RDD[(K, W)],
+    part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
 
   override def getDependencies: Seq[Dependency[_]] = {
     Seq(rdd1, rdd2).map { rdd =>
@@ -39,26 +40,7 @@ private[spark] class SubtractedRDD[T: ClassManifest](
         new OneToOneDependency(rdd)
       } else {
         logInfo("Adding shuffle dependency with " + rdd)
-        val mapSideCombinedRDD = rdd.mapPartitions(i => {
-          val set = new JHashSet[T]()
-          while (i.hasNext) {
-            set.add(i.next)
-          }
-          set.iterator
-        }, true)
-        // ShuffleDependency requires a tuple (k, v), which it will partition by k.
-        // We need this to partition to map to the same place as the k for
-        // OneToOneDependency, which means:
-        // - for already-tupled RDD[(A, B)], into getPartition(a)
-        // - for non-tupled RDD[C], into getPartition(c)
-        val part2 = new Partitioner() {
-          def numPartitions = part.numPartitions
-          def getPartition(key: Any) = key match {
-            case (k, v) => part.getPartition(k)
-            case k => part.getPartition(k)
-          }
-        }
-        new ShuffleDependency(mapSideCombinedRDD.map((_, null)), part2)
+        new ShuffleDependency(rdd.asInstanceOf[RDD[(K, Any)]], part)
       }
     }
   }
@@ -81,22 +63,32 @@ private[spark] class SubtractedRDD[T: ClassManifest](
 
   override val partitioner = Some(part)
 
-  override def compute(p: Partition, context: TaskContext): Iterator[T] = {
+  override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
     val partition = p.asInstanceOf[CoGroupPartition]
-    val set = new JHashSet[T]
-    def integrate(dep: CoGroupSplitDep, op: T => Unit) = dep match {
+    val map = new JHashMap[K, ArrayBuffer[V]]
+    def getSeq(k: K): ArrayBuffer[V] = {
+      val seq = map.get(k)
+      if (seq != null) {
+        seq
+      } else {
+        val seq = new ArrayBuffer[V]()
+        map.put(k, seq)
+        seq
+      }
+    }
+    def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match {
       case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
-        for (k <- rdd.iterator(itsSplit, context))
-          op(k.asInstanceOf[T])
+        for (t <- rdd.iterator(itsSplit, context))
+          op(t.asInstanceOf[(K, V)])
       case ShuffleCoGroupSplitDep(shuffleId) =>
-        for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
-          op(k.asInstanceOf[T])
+        for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
+          op(t.asInstanceOf[(K, V)])
     }
-    // the first dep is rdd1; add all keys to the set
-    integrate(partition.deps(0), set.add)
-    // the second dep is rdd2; remove all of its keys from the set
-    integrate(partition.deps(1), set.remove)
-    set.iterator
+    // the first dep is rdd1; add all values to the map
+    integrate(partition.deps(0), t => getSeq(t._1) += t._2)
+    // the second dep is rdd2; remove all of its keys
+    integrate(partition.deps(1), t => map.remove(t._1))
+    map.iterator.map { t =>  t._2.iterator.map { (t._1, _) } }.flatten
   }
 
   override def clearDependencies() {
@@ -105,4 +97,4 @@ private[spark] class SubtractedRDD[T: ClassManifest](
     rdd2 = null
   }
 
-}
\ No newline at end of file
+}
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 8411291b2caa31e86f58bbf17f39fbf68020a669..2b2a90defa4e902a8db7fb5ab2bc13da411b5913 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -272,13 +272,39 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     }
     // partitionBy so we have a narrow dependency
     val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
-    println(sc.runJob(a, (i: Iterator[(Int, String)]) => i.toList).toList)
     // more partitions/no partitioner so a shuffle dependency 
     val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
     val c = a.subtract(b)
     assert(c.collect().toSet === Set((1, "a"), (3, "c")))
+    // Ideally we could keep the original partitioner...
+    assert(c.partitioner === None)
+  }
+
+  test("subtractByKey") {
+    sc = new SparkContext("local", "test")
+    val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
+    val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
+    val c = a.subtractByKey(b)
+    assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+    assert(c.partitions.size === a.partitions.size)
+  }
+
+  test("subtractByKey with narrow dependency") {
+    sc = new SparkContext("local", "test")
+    // use a deterministic partitioner
+    val p = new Partitioner() {
+      def numPartitions = 5
+      def getPartition(key: Any) = key.asInstanceOf[Int]
+    }
+    // partitionBy so we have a narrow dependency
+    val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+    // more partitions/no partitioner so a shuffle dependency 
+    val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+    val c = a.subtractByKey(b)
+    assert(c.collect().toSet === Set((1, "a"), (1, "a")))
     assert(c.partitioner.get === p)
   }
+
 }
 
 object ShuffleSuite {