From d7f78b443b7c31b4db4eabb106801dc4a1866db7 Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Sun, 11 Aug 2013 12:05:09 -0700
Subject: [PATCH] Change scala.Option to Guava Optional in Java APIs.

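Guava's Optional is much friendlier to consume from Java than
scala.Option, so the Java APIs now return Optional instead, converting
values through a new JavaUtils.optionToOptional helper.

For example, a Java caller of the new leftOuterJoin signature can
consume the result as sketched below (an illustrative snippet, not part
of this patch, assuming pair RDDs rdd1 and rdd2 like those in the new
JavaAPISuite test):

    JavaPairRDD<Integer, Tuple2<Integer, Optional<Character>>> joined =
      rdd1.leftOuterJoin(rdd2);
    for (Tuple2<Integer, Tuple2<Integer, Optional<Character>>> t : joined.collect()) {
      Optional<Character> w = t._2()._2();
      // Optional.absent() marks keys of rdd1 with no match in rdd2:
      System.out.println(t._1() + (w.isPresent() ? " -> " + w.get() : " (unmatched)"));
    }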
---
 .../scala/spark/api/java/JavaPairRDD.scala    | 47 ++++++++++++++++--------
 .../scala/spark/api/java/JavaRDDLike.scala    |  5 +--
 .../spark/api/java/JavaSparkContext.scala     |  6 ++--
 .../main/scala/spark/api/java/JavaUtils.scala | 29 ++++++++++++++
 core/src/test/scala/spark/JavaAPISuite.java   | 31 +++++++++++++++
 .../streaming/api/java/JavaPairDStream.scala  |  9 ++---
 .../tools/JavaAPICompletenessChecker.scala    | 36 ++++++++++--------
 7 files changed, 120 insertions(+), 43 deletions(-)
 create mode 100644 core/src/main/scala/spark/api/java/JavaUtils.scala

diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index ccc511dc5f..6e00ef955a 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -41,6 +41,8 @@ import spark.Partitioner._
 import spark.RDD
 import spark.SparkContext.rddToPairRDDFunctions
 
+import com.google.common.base.Optional
+
 class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K],
   implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
 
@@ -276,8 +278,11 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
    * partition the output RDD.
    */
   def leftOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
-  : JavaPairRDD[K, (V, Option[W])] =
-    fromRDD(rdd.leftOuterJoin(other, partitioner))
+  : JavaPairRDD[K, (V, Optional[W])] = {
+    val joinResult = rdd.leftOuterJoin(other, partitioner)
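+    // Convert each value's Scala Option to a Guava Optional for Java callers: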
+    fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))})
+  }
 
   /**
    * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
@@ -286,8 +291,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
    * partition the output RDD.
    */
   def rightOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
-  : JavaPairRDD[K, (Option[V], W)] =
-    fromRDD(rdd.rightOuterJoin(other, partitioner))
+  : JavaPairRDD[K, (Optional[V], W)] = {
+    val joinResult = rdd.rightOuterJoin(other, partitioner)
+    fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
+  }
 
   /** 
    * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing
@@ -340,8 +347,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
-   * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
+   * pair (k, (v, Optional.absent())) if no elements in `other` have key k. Hash-partitions the output
    * using the existing partitioner/parallelism level.
    */
-  def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Option[W])] =
-    fromRDD(rdd.leftOuterJoin(other))
+  def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Optional[W])] = {
+    val joinResult = rdd.leftOuterJoin(other)
+    fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))})
+  }
 
   /**
    * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
@@ -349,8 +358,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
-   * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
+   * pair (k, (v, Optional.absent())) if no elements in `other` have key k. Hash-partitions the output
    * into `numPartitions` partitions.
    */
-  def leftOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, Option[W])] =
-    fromRDD(rdd.leftOuterJoin(other, numPartitions))
+  def leftOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, Optional[W])] = {
+    val joinResult = rdd.leftOuterJoin(other, numPartitions)
+    fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))})
+  }
 
   /**
    * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
@@ -358,8 +369,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
-   * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
+   * pair (k, (Optional.absent(), w)) if no elements in `this` have key k. Hash-partitions the resulting
    * RDD using the existing partitioner/parallelism level.
    */
-  def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Option[V], W)] =
-    fromRDD(rdd.rightOuterJoin(other))
+  def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Optional[V], W)] = {
+    val joinResult = rdd.rightOuterJoin(other)
+    fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
+  }
 
   /**
    * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
@@ -367,8 +380,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
-   * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
+   * pair (k, (Optional.absent(), w)) if no elements in `this` have key k. Hash-partitions the resulting
    * RDD into the given number of partitions.
    */
-  def rightOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (Option[V], W)] =
-    fromRDD(rdd.rightOuterJoin(other, numPartitions))
+  def rightOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (Optional[V], W)] = {
+    val joinResult = rdd.rightOuterJoin(other, numPartitions)
+    fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
+  }
 
   /**
    * Return the key-value pairs in this RDD to the master as a Map.
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 21b5abf053..e0255ed23e 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -366,10 +366,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    * Gets the name of the file to which this RDD was checkpointed
    */
   def getCheckpointFile(): Optional[String] = {
-    rdd.getCheckpointFile match {
-      case Some(file) => Optional.of(file)
-      case _ => Optional.absent()
-    }
+    JavaUtils.optionToOptional(rdd.getCheckpointFile)
   }
 
   /** A description of this RDD and its recursive dependencies for debugging. */
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index fe182e7ab6..29d57004b5 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -32,6 +32,8 @@ import spark.SparkContext.IntAccumulatorParam
 import spark.SparkContext.DoubleAccumulatorParam
 import spark.broadcast.Broadcast
 
+import com.google.common.base.Optional
+
 /**
  * A Java-friendly version of [[spark.SparkContext]] that returns [[spark.api.java.JavaRDD]]s and
  * works with Java collections instead of Scala ones.
@@ -337,7 +339,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    * or the spark.home Java property, or the SPARK_HOME environment variable
-   * (in that order of preference). If neither of these is set, return None.
+   * (in that order of preference). If none of these is set, returns Optional.absent().
    */
-  def getSparkHome(): Option[String] = sc.getSparkHome()
+  def getSparkHome(): Optional[String] = JavaUtils.optionToOptional(sc.getSparkHome())
 
   /**
    * Add a file to be downloaded with this Spark job on every node.
diff --git a/core/src/main/scala/spark/api/java/JavaUtils.scala b/core/src/main/scala/spark/api/java/JavaUtils.scala
new file mode 100644
index 0000000000..ffc131ac83
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/JavaUtils.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.api.java
+
+import com.google.common.base.Optional
+
+object JavaUtils {
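+  /** Converts a scala.Option to a Guava Optional; None maps to Optional.absent(). */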
+  def optionToOptional[T](option: Option[T]): Optional[T] =
+    option match {
+      case Some(value) => Optional.of(value)
+      case None => Optional.absent()
+    }
+}
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 5e2bf2d231..4ab271de1a 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -22,6 +22,7 @@ import java.io.IOException;
 import java.io.Serializable;
 import java.util.*;
 
+import com.google.common.base.Optional;
 import scala.Tuple2;
 
 import com.google.common.base.Charsets;
@@ -197,6 +198,36 @@ public class JavaAPISuite implements Serializable {
     cogrouped.collect();
   }
 
+  @Test
+  public void leftOuterJoin() {
+    JavaPairRDD<Integer, Integer> rdd1 = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<Integer, Integer>(1, 1),
+      new Tuple2<Integer, Integer>(1, 2),
+      new Tuple2<Integer, Integer>(2, 1),
+      new Tuple2<Integer, Integer>(3, 1)
+      ));
+    JavaPairRDD<Integer, Character> rdd2 = sc.parallelizePairs(Arrays.asList(
+      new Tuple2<Integer, Character>(1, 'x'),
+      new Tuple2<Integer, Character>(2, 'y'),
+      new Tuple2<Integer, Character>(2, 'z'),
+      new Tuple2<Integer, Character>(4, 'w')
+    ));
+    List<Tuple2<Integer,Tuple2<Integer,Optional<Character>>>> joined =
+      rdd1.leftOuterJoin(rdd2).collect();
+    Assert.assertEquals(5, joined.size());
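+    // Key 3 appears only in rdd1, so its Optional from the join is absent: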
+    Tuple2<Integer,Tuple2<Integer,Optional<Character>>> firstUnmatched =
+      rdd1.leftOuterJoin(rdd2).filter(
+        new Function<Tuple2<Integer, Tuple2<Integer, Optional<Character>>>, Boolean>() {
+          @Override
+          public Boolean call(Tuple2<Integer, Tuple2<Integer, Optional<Character>>> tup)
+            throws Exception {
+            return !tup._2()._2().isPresent();
+          }
+      }).first();
+    Assert.assertEquals(3, firstUnmatched._1().intValue());
+  }
+
   @Test
   public void foldReduce() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13));
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala
index ccd15563b0..ea08fb3826 100644
--- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala
@@ -29,7 +29,7 @@ import spark.{RDD, Partitioner}
 import org.apache.hadoop.mapred.{JobConf, OutputFormat}
 import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
 import org.apache.hadoop.conf.Configuration
-import spark.api.java.{JavaRDD, JavaPairRDD}
+import spark.api.java.{JavaUtils, JavaRDD, JavaPairRDD}
 import spark.storage.StorageLevel
 import com.google.common.base.Optional
 import spark.RDD
@@ -401,10 +401,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
   (Seq[V], Option[S]) => Option[S] = {
     val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => {
       val list: JList[V] = values
-      val scalaState: Optional[S] = state match {
-        case Some(s) => Optional.of(s)
-        case _ => Optional.absent()
-      }
-      val result: Optional[S] = in.apply(list, scalaState)
+      val javaState: Optional[S] = JavaUtils.optionToOptional(state)
+      val result: Optional[S] = in.apply(list, javaState)
       result.isPresent match {
         case true => Some(result.get())
diff --git a/tools/src/main/scala/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/spark/tools/JavaAPICompletenessChecker.scala
index 3a55f50812..30fded12f0 100644
--- a/tools/src/main/scala/spark/tools/JavaAPICompletenessChecker.scala
+++ b/tools/src/main/scala/spark/tools/JavaAPICompletenessChecker.scala
@@ -121,7 +121,9 @@ object JavaAPICompletenessChecker {
     SparkMethod(name, returnType, parameters)
   }
 
-  private def toJavaType(scalaType: SparkType): SparkType = {
+  private def toJavaType(scalaType: SparkType, isReturnType: Boolean): SparkType = {
+    // Nested types are converted with parameter-position (not return-position) rules:
+    def applySubs(t: SparkType): SparkType = toJavaType(t, isReturnType = false)
     val renameSubstitutions = Map(
       "scala.collection.Map" -> "java.util.Map",
       // TODO: the JavaStreamingContext API accepts Array arguments
@@ -140,40 +142,43 @@ object JavaAPICompletenessChecker {
             case "spark.RDD" =>
               if (parameters(0).name == classOf[Tuple2[_, _]].getName) {
                 val tupleParams =
-                  parameters(0).asInstanceOf[ParameterizedType].parameters.map(toJavaType)
+                  parameters(0).asInstanceOf[ParameterizedType].parameters.map(applySubs)
                 ParameterizedType(classOf[JavaPairRDD[_, _]].getName, tupleParams)
               } else {
-                ParameterizedType(classOf[JavaRDD[_]].getName, parameters.map(toJavaType))
+                ParameterizedType(classOf[JavaRDD[_]].getName, parameters.map(applySubs))
               }
             case "spark.streaming.DStream" =>
               if (parameters(0).name == classOf[Tuple2[_, _]].getName) {
                 val tupleParams =
-                  parameters(0).asInstanceOf[ParameterizedType].parameters.map(toJavaType)
+                  parameters(0).asInstanceOf[ParameterizedType].parameters.map(applySubs)
                 ParameterizedType("spark.streaming.api.java.JavaPairDStream", tupleParams)
               } else {
                 ParameterizedType("spark.streaming.api.java.JavaDStream",
-                  parameters.map(toJavaType))
+                  parameters.map(applySubs))
               }
-            // TODO: Spark Streaming uses Guava's Optional in place of Option, leading to some
-            // false-positives here:
-            case "scala.Option" =>
-              toJavaType(parameters(0))
+            case "scala.Option" => {
+              if (isReturnType) {
+                ParameterizedType("com.google.common.base.Optional", parameters.map(applySubs))
+              } else {
+                applySubs(parameters(0))
+              }
+            }
             case "scala.Function1" =>
               val firstParamName = parameters.last.name
               if (firstParamName.startsWith("scala.collection.Traversable") ||
                 firstParamName.startsWith("scala.collection.Iterator")) {
                 ParameterizedType("spark.api.java.function.FlatMapFunction",
                   Seq(parameters(0),
-                    parameters.last.asInstanceOf[ParameterizedType].parameters(0)).map(toJavaType))
+                    parameters.last.asInstanceOf[ParameterizedType].parameters(0)).map(applySubs))
               } else if (firstParamName == "scala.runtime.BoxedUnit") {
                 ParameterizedType("spark.api.java.function.VoidFunction",
-                  parameters.dropRight(1).map(toJavaType))
+                  parameters.dropRight(1).map(applySubs))
               } else {
-                ParameterizedType("spark.api.java.function.Function", parameters.map(toJavaType))
+                ParameterizedType("spark.api.java.function.Function", parameters.map(applySubs))
               }
             case _ =>
               ParameterizedType(renameSubstitutions.getOrElse(name, name),
-                parameters.map(toJavaType))
+                parameters.map(applySubs))
           }
         case BaseType(name) =>
           if (renameSubstitutions.contains(name)) {
@@ -194,8 +199,9 @@ object JavaAPICompletenessChecker {
 
   private def toJavaMethod(method: SparkMethod): SparkMethod = {
     val params = method.parameters
-      .filterNot(_.name == "scala.reflect.ClassManifest").map(toJavaType)
-    SparkMethod(method.name, toJavaType(method.returnType), params)
+      .filterNot(_.name == "scala.reflect.ClassManifest")
+      .map(toJavaType(_, isReturnType = false))
+    SparkMethod(method.name, toJavaType(method.returnType, isReturnType = true), params)
   }
 
   private def isExcludedByName(method: Method): Boolean = {
-- 
GitLab