diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index de0d9fad88ee32554dd799635d639c460063f2c5..ce5f1719117a6c606c1b2d2fcad12594a8ba9907 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -1,7 +1,8 @@
 package spark.rdd
 
+import java.util.{HashMap => JHashMap}
+import scala.collection.JavaConversions
 import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
 
 import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext}
 import spark.{Dependency, OneToOneDependency, ShuffleDependency}
@@ -71,9 +72,16 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
   override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
     val split = s.asInstanceOf[CoGroupSplit]
     val numRdds = split.deps.size
-    val map = new HashMap[K, Seq[ArrayBuffer[Any]]]
+    val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
     def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
-      map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any]))
+      val seq = map.get(k)
+      if (seq != null) {
+        seq
+      } else {
+        val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
+        map.put(k, seq)
+        seq
+      }
     }
     for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
       case NarrowCoGroupSplitDep(rdd, itsSplit) => {
@@ -93,6 +101,6 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
         fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair)
       }
     }
-    map.iterator
+    JavaConversions.mapAsScalaMap(map).iterator
   }
 }