Skip to content
Snippets Groups Projects
Commit 9a344098 authored by Matei Zaharia
Browse files

Merge pull request #360 from rxin/cogroup-java

Changed CoGroupedRDD's hash map from Scala's mutable HashMap to java.util.HashMap.
parents 62e47673 be716614
No related branches found
No related tags found
No related merge requests found
package spark.rdd package spark.rdd
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext} import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Split, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency} import spark.{Dependency, OneToOneDependency, ShuffleDependency}
...@@ -71,9 +72,16 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) ...@@ -71,9 +72,16 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = { override def compute(s: Split, context: TaskContext): Iterator[(K, Seq[Seq[_]])] = {
val split = s.asInstanceOf[CoGroupSplit] val split = s.asInstanceOf[CoGroupSplit]
val numRdds = split.deps.size val numRdds = split.deps.size
val map = new HashMap[K, Seq[ArrayBuffer[Any]]] val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = { def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) val seq = map.get(k)
if (seq != null) {
seq
} else {
val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
map.put(k, seq)
seq
}
} }
for ((dep, depNum) <- split.deps.zipWithIndex) dep match { for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, itsSplit) => { case NarrowCoGroupSplitDep(rdd, itsSplit) => {
...@@ -93,6 +101,6 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) ...@@ -93,6 +101,6 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner)
fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair) fetcher.fetch[K, Seq[Any]](shuffleId, split.index).foreach(mergePair)
} }
} }
map.iterator JavaConversions.mapAsScalaMap(map).iterator
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment