diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index 1bd1741a713943499be7d3c700e91229f836998a..47b9c6962f951033bf5e3cf696b2179cd3163c91 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -440,6 +440,23 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
     cogroup(other1, other2, defaultPartitioner(self, other1, other2))
   }
 
+  /**
+   * Return an RDD with the pairs from `this` whose keys are not in `other`.
+   *
+   * Uses `this` RDD's partitioner/partition size, because even if `other` is huge, the
+   * resulting RDD will be no larger than this one.
+   */
+  def subtractByKey(other: RDD[(K, V)]): RDD[(K, V)] =
+    subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))
+
+  /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+  def subtractByKey(other: RDD[(K, V)], numPartitions: Int): RDD[(K, V)] =
+    subtractByKey(other, new HashPartitioner(numPartitions))
+
+  /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+  def subtractByKey(other: RDD[(K, V)], p: Partitioner): RDD[(K, V)] =
+    new SubtractedRDD[K, V](self, other, p)
+
   /**
    * Return the list of values in the RDD for key `key`. This operation is done efficiently if the
    * RDD has a known partitioner by only searching the partition that the key maps to.
@@ -639,8 +656,6 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
       }
     }, true)
   }
-
-  // def subtractByKey(other: RDD[K]): RDD[(K,V)] = subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
 }
 
 private[spark]
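A rough usage sketch of the new subtractByKey overloads (not part of the patch; it assumes a live SparkContext named `sc` and the usual `import spark.SparkContext._` so the implicit conversion to PairRDDFunctions is in scope):

    import spark.SparkContext._          // implicit rddToPairRDDFunctions
    import spark.HashPartitioner

    val left  = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
    val right = sc.parallelize(Array((2, "bb"), (3, "cc"), (4, "dd")), 4)

    // Default overload: reuses left's partitioner, or a HashPartitioner sized like left
    left.subtractByKey(right).collect()                        // keeps (1, "a") twice
    // Explicit partition count or explicit Partitioner, same semantics
    left.subtractByKey(right, 3).collect()
    left.subtractByKey(right, new HashPartitioner(3)).collect()
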
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 3451136fd40eb3dd439b66da508c45fd4d42883f..9bd8a0f98daa8a47c2b625f68eb1af4b02c0629a 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -408,24 +408,8 @@ abstract class RDD[T: ClassManifest](
    * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
    * RDD will be <= us.
    */
-  def subtract(other: RDD[T]): RDD[T] = {
-    // If we do have a partitioner, our T is really (K, V), and we'll need to
-    // unwrap the (T, null) that subtract does to get back to the K
-    val rdd = subtract(other, partitioner match {
-      case None => new HashPartitioner(partitions.size)
-      case Some(p) => new Partitioner() {
-        override def numPartitions = p.numPartitions
-        override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1)
-      }
-    })
-    // Hacky, but if we did have a partitioner, we can keep using it
-    new RDD[T](rdd) {
-      override def getPartitions = rdd.partitions
-      override def getDependencies = rdd.dependencies
-      override def compute(split: Partition, context: TaskContext) = rdd.compute(split, context)
-      override val partitioner = RDD.this.partitioner
-    }
-  }
+  def subtract(other: RDD[T]): RDD[T] =
+    subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
 
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
@@ -437,7 +421,21 @@ abstract class RDD[T: ClassManifest](
    * Return an RDD with the elements from `this` that are not in `other`.
    */
   def subtract(other: RDD[T], p: Partitioner): RDD[T] = {
-     new SubtractedRDD[T, Any](this.map((_, null)), other.map((_, null)), p).keys
+    if (partitioner == Some(p)) {
+      // Our partitioner knows how to handle T (which, since we have a partitioner, is
+      // really (K, V)), so make a new Partitioner that unwraps the tuple and partitions by the real K
+      val p2 = new Partitioner() {
+        override def numPartitions = p.numPartitions
+        override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1)
+      }
+      // Unfortunately, since we're making a new p2, we'll still get ShuffleDependencies,
+      // and the RDD returned by .keys will not have a partitioner set, even though the
+      // SubtractedRDD will, thanks to p2's de-tupled partitioning, already be
+      // partitioned by the real keys (i.e. by p).
+      this.map(x => (x, null)).subtractByKey(other.map((_, null)), p2).keys
+    } else {
+      this.map(x => (x, null)).subtractByKey(other.map((_, null)), p).keys
+    }
   }
 
   /**
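For reference, a standalone sketch (the helper name is ours, not the patch's) of the key-unwrapping trick that p2 performs above: once elements are wrapped as (x, null) pairs, the key seen by the shuffle is the whole (K, V) tuple, so a partitioner written for the real key K has to be adapted:

    import spark.{HashPartitioner, Partitioner}

    def deTuple(p: Partitioner): Partitioner = new Partitioner {
      override def numPartitions = p.numPartitions
      // the key handed in is the (K, V) tuple; route it by its inner K
      override def getPartition(key: Any) = p.getPartition(key.asInstanceOf[(Any, _)]._1)
    }

    // deTuple(new HashPartitioner(4)).getPartition((1, "a")) == new HashPartitioner(4).getPartition(1)
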
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 1bc84f7e1ef82a2c1bb8a5b4bb50911bdea9b948..90488f13cce2c9fd4e169f12d7a2322b97a1d9e3 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -87,7 +87,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest](
     // the first dep is rdd1; add all values to the map
     integrate(partition.deps(0), t => getSeq(t._1) += t._2)
     // the second dep is rdd2; remove all of its keys
-    integrate(partition.deps(1), t => map.remove(t._1) )
+    integrate(partition.deps(1), t => map.remove(t._1))
     map.iterator.map { t =>  t._2.iterator.map { (t._1, _) } }.flatten
   }
 
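For readers skimming the patch, a Spark-free sketch of the per-partition logic that SubtractedRDD.compute runs above (the helper name and toy data are ours): collect every value from the first dependency into a key -> values map, drop every key that appears in the second, then flatten what is left:

    import scala.collection.mutable.{ArrayBuffer, HashMap}

    def subtractByKeyLocal[K, V](left: Iterator[(K, V)], right: Iterator[(K, _)]): Iterator[(K, V)] = {
      val map = new HashMap[K, ArrayBuffer[V]]
      def getSeq(k: K) = map.getOrElseUpdate(k, new ArrayBuffer[V])
      left.foreach { case (k, v) => getSeq(k) += v }      // first dep: add all values
      right.foreach { case (k, _) => map.remove(k) }      // second dep: remove its keys
      map.iterator.flatMap { case (k, vs) => vs.iterator.map((k, _)) }
    }

    // subtractByKeyLocal(Iterator((1, "a"), (2, "b")), Iterator((2, "x"))).toList == List((1, "a"))
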
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index 8411291b2caa31e86f58bbf17f39fbf68020a669..731c45cca2fe6875474ea55a340cbe245403f2ac 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -272,13 +272,39 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     }
     // partitionBy so we have a narrow dependency
     val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
-    println(sc.runJob(a, (i: Iterator[(Int, String)]) => i.toList).toList)
     // more partitions/no partitioner so a shuffle dependency 
     val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
     val c = a.subtract(b)
     assert(c.collect().toSet === Set((1, "a"), (3, "c")))
+    // Ideally we could keep the original partitioner...
+    assert(c.partitioner === None)
+  }
+
+  test("subtractByKey") {
+    sc = new SparkContext("local", "test")
+    val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
+    val b = sc.parallelize(Array((2, "bb"), (3, "cc"), (4, "dd")), 4)
+    val c = a.subtractByKey(b)
+    assert(c.collect().toSet === Set((1, "a"), (1, "a")))
+    assert(c.partitions.size === a.partitions.size)
+  }
+
+  test("subtractByKey with narrow dependency") {
+    sc = new SparkContext("local", "test")
+    // use a deterministic partitioner
+    val p = new Partitioner() {
+      def numPartitions = 5
+      def getPartition(key: Any) = key.asInstanceOf[Int]
+    }
+    // partitionBy so we have a narrow dependency
+    val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
+    // more partitions/no partitioner so a shuffle dependency 
+    val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
+    val c = a.subtractByKey(b)
+    assert(c.collect().toSet === Set((1, "a"), (1, "a")))
     assert(c.partitioner.get === p)
   }
+
 }
 
 object ShuffleSuite {