diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index e7408e4352abfc2e53e20f91039e186ff72c139f..1bd1741a713943499be7d3c700e91229f836998a 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -639,6 +639,8 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
       }
     }, true)
   }
+
+  // def subtractByKey(other: RDD[K]): RDD[(K,V)] = subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
 }
 
 private[spark]
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 584efa8adf5f3877c914cf54b0fd22eefc325963..3451136fd40eb3dd439b66da508c45fd4d42883f 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -408,8 +408,24 @@ abstract class RDD[T: ClassManifest](
    * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
    * RDD will be <= us.
    */
-  def subtract(other: RDD[T]): RDD[T] =
-    subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
+  def subtract(other: RDD[T]): RDD[T] = {
+    // If we do have a partitioner, our T is really (K, V), and we'll need to
+    // unwrap the (T, null) that subtract does to get back to the K
+    val rdd = subtract(other, partitioner match {
+      case None => new HashPartitioner(partitions.size)
+      case Some(p) => new Partitioner() {
+        override def numPartitions = p.numPartitions
+        override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1)
+      }
+    })
+    // Hacky, but if we did have a partitioner, we can keep using it
+    new RDD[T](rdd) {
+      override def getPartitions = rdd.partitions
+      override def getDependencies = rdd.dependencies
+      override def compute(split: Partition, context: TaskContext) = rdd.compute(split, context)
+      override val partitioner = RDD.this.partitioner
+    }
+  }
 
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
@@ -420,7 +436,9 @@ abstract class RDD[T: ClassManifest](
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
    */
-  def subtract(other: RDD[T], p: Partitioner): RDD[T] = new SubtractedRDD[T](this, other, p)
+  def subtract(other: RDD[T], p: Partitioner): RDD[T] = {
+     new SubtractedRDD[T, Any](this.map((_, null)), other.map((_, null)), p).keys
+  }
 
   /**
    * Reduces the elements of this RDD using the specified commutative and associative binary operator.
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index 43ec90cac5a95111534a7e15d9638c438b71dde5..1bc84f7e1ef82a2c1bb8a5b4bb50911bdea9b948 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -1,7 +1,8 @@
 package spark.rdd
 
-import java.util.{HashSet => JHashSet}
+import java.util.{HashMap => JHashMap}
 import scala.collection.JavaConversions._
+import scala.collection.mutable.ArrayBuffer
 import spark.RDD
 import spark.Partitioner
 import spark.Dependency
@@ -27,39 +28,20 @@ import spark.OneToOneDependency
  * you can use `rdd1`'s partitioner/partition size and not worry about running
  * out of memory because of the size of `rdd2`.
  */
-private[spark] class SubtractedRDD[T: ClassManifest](
-    @transient var rdd1: RDD[T],
-    @transient var rdd2: RDD[T],
-    part: Partitioner) extends RDD[T](rdd1.context, Nil) {
+private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest](
+    @transient var rdd1: RDD[(K, V)],
+    @transient var rdd2: RDD[(K, V)],
+    part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) {
 
   override def getDependencies: Seq[Dependency[_]] = {
     Seq(rdd1, rdd2).map { rdd =>
-      if (rdd.partitioner == Some(part)) {
-        logInfo("Adding one-to-one dependency with " + rdd)
-        new OneToOneDependency(rdd)
-      } else {
-        logInfo("Adding shuffle dependency with " + rdd)
-        val mapSideCombinedRDD = rdd.mapPartitions(i => {
-          val set = new JHashSet[T]()
-          while (i.hasNext) {
-            set.add(i.next)
-          }
-          set.iterator
-        }, true)
-        // ShuffleDependency requires a tuple (k, v), which it will partition by k.
-        // We need this to partition to map to the same place as the k for
-        // OneToOneDependency, which means:
-        // - for already-tupled RDD[(A, B)], into getPartition(a)
-        // - for non-tupled RDD[C], into getPartition(c)
-        val part2 = new Partitioner() {
-          def numPartitions = part.numPartitions
-          def getPartition(key: Any) = key match {
-            case (k, v) => part.getPartition(k)
-            case k => part.getPartition(k)
-          }
-        }
-        new ShuffleDependency(mapSideCombinedRDD.map((_, null)), part2)
-      }
+	  if (rdd.partitioner == Some(part)) {
+	    logInfo("Adding one-to-one dependency with " + rdd)
+	    new OneToOneDependency(rdd)
+	  } else {
+	    logInfo("Adding shuffle dependency with " + rdd)
+	    new ShuffleDependency(rdd, part)
+	  }
     }
   }
 
@@ -81,22 +63,32 @@ private[spark] class SubtractedRDD[T: ClassManifest](
 
   override val partitioner = Some(part)
 
-  override def compute(p: Partition, context: TaskContext): Iterator[T] = {
+  override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
     val partition = p.asInstanceOf[CoGroupPartition]
-    val set = new JHashSet[T]
-    def integrate(dep: CoGroupSplitDep, op: T => Unit) = dep match {
+    val map = new JHashMap[K, ArrayBuffer[V]]
+    def getSeq(k: K): ArrayBuffer[V] = {
+      val seq = map.get(k)
+      if (seq != null) {
+        seq
+      } else {
+        val seq = new ArrayBuffer[V]()
+        map.put(k, seq)
+        seq
+      }
+    }
+    def integrate(dep: CoGroupSplitDep, op: ((K, V)) => Unit) = dep match {
       case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
-        for (k <- rdd.iterator(itsSplit, context))
-          op(k.asInstanceOf[T])
+        for (t <- rdd.iterator(itsSplit, context))
+          op(t.asInstanceOf[(K, V)])
       case ShuffleCoGroupSplitDep(shuffleId) =>
-        for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
-          op(k.asInstanceOf[T])
+        for (t <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
+          op(t.asInstanceOf[(K, V)])
     }
-    // the first dep is rdd1; add all keys to the set
-    integrate(partition.deps(0), set.add)
-    // the second dep is rdd2; remove all of its keys from the set
-    integrate(partition.deps(1), set.remove)
-    set.iterator
+    // the first dep is rdd1; add all values to the map
+    integrate(partition.deps(0), t => getSeq(t._1) += t._2)
+    // the second dep is rdd2; remove all of its keys
+    integrate(partition.deps(1), t => map.remove(t._1) )
+    map.iterator.map { t =>  t._2.iterator.map { (t._1, _) } }.flatten
   }
 
   override def clearDependencies() {