diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 8fafd27bb6985a1e707c79d25cc74ed91dd9cf72..4893fe8d784bf14cba83d55298c5bd533083bbda 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -84,6 +84,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
   override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
     val split = s.asInstanceOf[CoGroupSplit]
     val numRdds = split.deps.size
+    // e.g. for `(k, a) cogroup (k, b)`, map holds k -> Seq(ArrayBuffer of a values, ArrayBuffer of b values)
     val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
     def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
       val seq = map.get(k)
@@ -104,13 +105,10 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner)
       }
       case ShuffleCoGroupSplitDep(shuffleId) => {
         // Read map outputs of shuffle
-        def mergePair(pair: (K, Seq[Any])) {
-          val mySeq = getSeq(pair._1)
-          for (v <- pair._2)
-            mySeq(depNum) += v
-        }
         val fetcher = SparkEnv.get.shuffleFetcher
-        fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair)
+        for ((k, vs) <- fetcher.fetch[K, Seq[Any]](shuffleId, split.index)) {
+          getSeq(k)(depNum) ++= vs
+        }
       }
     }
     JavaConversions.mapAsScalaMap(map).iterator
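
For illustration, here is a minimal, self-contained sketch of the pattern this change adopts: one buffer per parent RDD under each key, with each fetched value sequence appended in bulk via `++=` instead of the removed per-value `+=` loop inside `mergePair`. The object name, the `String` key type, and the hard-coded `fetched` pairs are stand-ins for this sketch only, not the real shuffle machinery:

```scala
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer

object CoGroupMergeSketch {
  def main(args: Array[String]) {
    val numRdds = 2  // pretend two parent RDDs are being cogrouped
    val map = new JHashMap[String, Seq[ArrayBuffer[Any]]]

    // Same shape as CoGroupedRDD.getSeq: lazily create one buffer per parent RDD.
    def getSeq(k: String): Seq[ArrayBuffer[Any]] = {
      val seq = map.get(k)
      if (seq != null) {
        seq
      } else {
        val fresh = Seq.fill(numRdds)(new ArrayBuffer[Any])
        map.put(k, fresh)
        fresh
      }
    }

    // Stand-in for the fetcher.fetch output: (key, values) pairs from dependency 0.
    val depNum = 0
    val fetched = Seq("x" -> Seq(1, 2), "y" -> Seq(3), "x" -> Seq(4))
    for ((k, vs) <- fetched) {
      getSeq(k)(depNum) ++= vs  // bulk append, as in the new loop above
    }

    // Prints the merged buffers, e.g. x -> (ArrayBuffer(1, 2, 4), ArrayBuffer())
    println(JavaConversions.mapAsScalaMap(map))
  }
}
```

Beyond being shorter, the bulk `++=` drops the intermediate `mergePair` helper and appends each fetched batch in one call instead of element by element, keeping the iteration shape visible at the call site.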