From 5ff0810b11e95e3b48d88ae744fdeaf7c117186d Mon Sep 17 00:00:00 2001
From: Mark Hamstra <markhamstra@gmail.com>
Date: Tue, 5 Mar 2013 12:25:44 -0800
Subject: [PATCH] refactor mapWith, flatMapWith and filterWith to each use two
 parameter lists

---
 core/src/main/scala/spark/RDD.scala      | 12 ++++++------
 core/src/test/scala/spark/RDDSuite.scala | 12 ++++++------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index cc206782d0..0a901a251d 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -372,10 +372,10 @@ abstract class RDD[T: ClassManifest](
    * and a seed value of type B.
    */
   def mapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
-    f:(A, T) => U,
     factoryBuilder: (Int, B) => (T => A),
     factorySeed: B,
-    preservesPartitioning: Boolean = false): RDD[U] = {
+    preservesPartitioning: Boolean = false)
+    (f:(A, T) => U): RDD[U] = {
       def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
         val factory = factoryBuilder(index, factorySeed)
         iter.map(t => f(factory(t), t))
@@ -391,10 +391,10 @@ abstract class RDD[T: ClassManifest](
    * and a seed value of type B.
    */
   def flatMapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
-    f:(A, T) => Seq[U],
     factoryBuilder: (Int, B) => (T => A),
     factorySeed: B,
-    preservesPartitioning: Boolean = false): RDD[U] = {
+    preservesPartitioning: Boolean = false)
+    (f:(A, T) => Seq[U]): RDD[U] = {
       def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
         val factory = factoryBuilder(index, factorySeed)
         iter.flatMap(t => f(factory(t), t))
@@ -410,10 +410,10 @@ abstract class RDD[T: ClassManifest](
    * and a seed value of type B.
    */
   def filterWith[A: ClassManifest, B: ClassManifest](
-    p:(A, T) => Boolean,
     factoryBuilder: (Int, B) => (T => A),
     factorySeed: B,
-    preservesPartitioning: Boolean = false): RDD[T] = {
+    preservesPartitioning: Boolean = false)
+    (p:(A, T) => Boolean): RDD[T] = {
       def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
         val factory = factoryBuilder(index, factorySeed)
         iter.filter(t => p(factory(t), t))
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index b549677469..2a182e0d6c 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -183,11 +183,11 @@ class RDDSuite extends FunSuite with LocalSparkContext {
     sc = new SparkContext("local", "test")
     val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
     val randoms = ones.mapWith(
-      (random: Double, t: Int) => random * t,
       (index: Int, seed: Int) => {
 	      val prng = new java.util.Random(index + seed)
 	      (_ => prng.nextDouble)},
-      42).
+      42)
+      {(random: Double, t: Int) => random * t}.
       collect()
     val prn42_3 = {
       val prng42 = new java.util.Random(42)
@@ -205,11 +205,11 @@ class RDDSuite extends FunSuite with LocalSparkContext {
     sc = new SparkContext("local", "test")
     val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
     val randoms = ones.flatMapWith(
-      (random: Double, t: Int) => Seq(random * t, random * t * 10),
       (index: Int, seed: Int) => {
         val prng = new java.util.Random(index + seed)
         (_ => prng.nextDouble)},
-      42).
+      42)
+      {(random: Double, t: Int) => Seq(random * t, random * t * 10)}.
       collect()
     val prn42_3 = {
       val prng42 = new java.util.Random(42)
@@ -228,11 +228,11 @@ class RDDSuite extends FunSuite with LocalSparkContext {
     sc = new SparkContext("local", "test")
     val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
     val sample = ints.filterWith(
-      (random: Int, t: Int) => random == 0,
       (index: Int, seed: Int) => {
 	      val prng = new Random(index + seed)
 	      (_ => prng.nextInt(3))},
-      42).
+      42)
+      {(random: Int, t: Int) => random == 0}.
       collect()
     val checkSample = {
       val prng42 = new Random(42)
-- 
GitLab