diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index b9fe7f604ee989c82b6e5c446f2781f209bc25d1..6fd7a0d15a09cb3ced2af1d12d46916341c2bbbc 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -228,6 +228,31 @@ class SparkContext(
         scheduler.initialize(backend)
         scheduler
 
+      case "yarn-client" =>
+        val scheduler = try {
+          val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler")
+          val cons = clazz.getConstructor(classOf[SparkContext])
+          cons.newInstance(this).asInstanceOf[ClusterScheduler]
+
+        } catch {
+          case th: Throwable => {
+            throw new SparkException("YARN mode not available ?", th)
+          }
+        }
+
+        val backend = try {
+          val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
+          val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext])
+          cons.newInstance(scheduler, this).asInstanceOf[CoarseGrainedSchedulerBackend]
+        } catch {
+          case th: Throwable => {
+            throw new SparkException("YARN mode not available ?", th)
+          }
+        }
+
+        scheduler.initialize(backend)
+        scheduler
+
       case MESOS_REGEX(mesosUrl) =>
         MesosNativeLibrary.load()
         val scheduler = new ClusterScheduler(this)
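
For context, the new match arm above lets a driver select the client-mode YARN scheduler purely by master string, with the YARN-specific classes loaded reflectively so that core has no compile-time dependency on them. A minimal, hypothetical driver sketch (application name and job are illustrative; it only works with the YARN assembly on the classpath, otherwise the SparkException above is thrown):

    import org.apache.spark.SparkContext

    object YarnClientExample {
      def main(args: Array[String]) {
        // "yarn-client" now resolves to YarnClientClusterScheduler plus
        // YarnClientSchedulerBackend via the reflection shown above.
        val sc = new SparkContext("yarn-client", "YarnClientExample")
        println(sc.parallelize(1 to 100).count())
        sc.stop()
      }
    }
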
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index e5e20dbb66a5b3b250d12fff55f068cdd9343b8d..da30cf619a1d0ecfabf501faecc0e2b0f0a64738 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -29,6 +29,8 @@ import org.apache.spark.storage.StorageLevel
 import java.lang.Double
 import org.apache.spark.Partitioner
 
+import scala.collection.JavaConverters._
+
 class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] {
 
   override val classTag: ClassTag[Double] = implicitly[ClassTag[Double]]
@@ -185,6 +187,44 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
 
   /** (Experimental) Approximate operation to return the sum within a timeout. */
   def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout)
+
+  /**
+   * Compute a histogram of the data using bucketCount buckets evenly spaced
+   * between the minimum and maximum of the RDD. For example, if the min value
+   * is 0, the max is 100 and there are two buckets, the resulting buckets will
+   * be [0, 50) and [50, 100]. bucketCount must be at least 1.
+   * If the RDD contains infinity or NaN, an exception is thrown.
+   * If the elements in the RDD do not vary (max == min), a single bucket is always returned.
+   */
+  def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = {
+    val result = srdd.histogram(bucketCount)
+    (result._1, result._2)
+  }
+
+  /**
+   * Compute a histogram using the provided buckets. The buckets are all open
+   * to the left except for the last one, which is closed.
+   * For example, the array
+   * [1, 10, 20, 50] describes the buckets [1, 10), [10, 20) and [20, 50],
+   * i.e. 1 <= x < 10, 10 <= x < 20 and 20 <= x <= 50.
+   * On the input 1 and 50 the resulting histogram would be [1, 0, 1].
+   *
+   * Note: if your buckets are evenly spaced (e.g. [0, 10, 20, 30]), bucket lookup
+   * can be switched from an O(log n) binary search to an O(1) computation per
+   * element (where n = # buckets) by setting evenBuckets to true.
+   * The buckets must be sorted and must not contain any duplicates.
+   * The buckets array must have at least two elements.
+   * All NaN entries are treated the same: if a NaN boundary is used, it must be
+   * the last (maximum) value, and all NaN entries will be counted
+   * in that bucket.
+   */
+  def histogram(buckets: Array[scala.Double]): Array[Long] = {
+    srdd.histogram(buckets, false)
+  }
+
+  def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = {
+    srdd.histogram(buckets.map(_.toDouble), evenBuckets)
+  }
 }
 
 object JavaDoubleRDD {
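
To illustrate the Java-facing API added above, here is a sketch (local master, object name and sample values are made up) that exercises both overloads from Scala through JavaSparkContext; histogram(bucketCount) returns the generated bucket boundaries together with the counts, while histogram(buckets) returns only the counts:

    import org.apache.spark.api.java.JavaSparkContext

    object JavaHistogramSketch {
      def main(args: Array[String]) {
        val jsc = new JavaSparkContext("local", "JavaHistogramSketch")
        val data = java.util.Arrays.asList[java.lang.Double](1.0, 2.0, 3.0, 4.0)
        val rdd = jsc.parallelizeDoubles(data)
        // Two generated buckets over [min, max]: boundaries (1.0, 2.5, 4.0), counts (2, 2)
        val (buckets, counts) = rdd.histogram(2)
        // Caller-provided buckets [1, 10) and [10, 20]: counts (4, 0)
        val counts2 = rdd.histogram(Array(1.0, 10.0, 20.0))
        println(buckets.mkString(",") + " -> " + counts.mkString(","))
        println(counts2.mkString(","))
        jsc.stop()
      }
    }
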
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 53b53df9ac7712666ef6189449b28b467b38f235..2bf7ac256eb92243a69746be876bc1bf4a448060 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -28,12 +28,11 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.PipedRDD
 import org.apache.spark.util.Utils
 
 private[spark] class PythonRDD[T: ClassTag](
     parent: RDD[T],
-    command: Seq[String],
+    command: Array[Byte],
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     preservePartitoning: Boolean,
@@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassTag](
 
   val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
 
-  // Similar to Runtime.exec(), if we are given a single string, split it into words
-  // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, envVars: JMap[String, String],
-      pythonIncludes: JList[String],
-      preservePartitoning: Boolean, pythonExec: String,
-      broadcastVars: JList[Broadcast[Array[Byte]]],
-      accumulator: Accumulator[JList[Array[Byte]]]) =
-    this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
-      broadcastVars, accumulator)
-
   override def getPartitions = parent.partitions
 
   override val partitioner = if (preservePartitoning) parent.partitioner else None
 
-
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
     val startTime = System.currentTimeMillis
     val env = SparkEnv.get
@@ -71,11 +59,10 @@ private[spark] class PythonRDD[T: ClassTag](
           SparkEnv.set(env)
           val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
           val dataOut = new DataOutputStream(stream)
-          val printOut = new PrintWriter(stream)
           // Partition index
           dataOut.writeInt(split.index)
           // sparkFilesDir
-          PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+          dataOut.writeUTF(SparkFiles.getRootDirectory)
           // Broadcast variables
           dataOut.writeInt(broadcastVars.length)
           for (broadcast <- broadcastVars) {
@@ -85,21 +72,16 @@ private[spark] class PythonRDD[T: ClassTag](
           }
           // Python includes (*.zip and *.egg files)
           dataOut.writeInt(pythonIncludes.length)
-          for (f <- pythonIncludes) {
-            PythonRDD.writeAsPickle(f, dataOut)
-          }
+          pythonIncludes.foreach(dataOut.writeUTF)
           dataOut.flush()
-          // Serialized user code
-          for (elem <- command) {
-            printOut.println(elem)
-          }
-          printOut.flush()
+          // Serialized command:
+          dataOut.writeInt(command.length)
+          dataOut.write(command)
           // Data values
           for (elem <- parent.iterator(split, context)) {
-            PythonRDD.writeAsPickle(elem, dataOut)
+            PythonRDD.writeToStream(elem, dataOut)
           }
           dataOut.flush()
-          printOut.flush()
           worker.shutdownOutput()
         } catch {
           case e: IOException =>
@@ -132,7 +114,7 @@ private[spark] class PythonRDD[T: ClassTag](
               val obj = new Array[Byte](length)
               stream.readFully(obj)
               obj
-            case -3 =>
+            case SpecialLengths.TIMING_DATA =>
               // Timing data from worker
               val bootTime = stream.readLong()
               val initTime = stream.readLong()
@@ -143,24 +125,24 @@ private[spark] class PythonRDD[T: ClassTag](
               val total = finishTime - startTime
               logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
               read
-            case -2 =>
+            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
               // Signals that an exception has been thrown in python
               val exLength = stream.readInt()
               val obj = new Array[Byte](exLength)
               stream.readFully(obj)
               throw new PythonException(new String(obj))
-            case -1 =>
+            case SpecialLengths.END_OF_DATA_SECTION =>
               // We've finished the data section of the output, but we can still
-              // read some accumulator updates; let's do that, breaking when we
-              // get a negative length record.
-              var len2 = stream.readInt()
-              while (len2 >= 0) {
-                val update = new Array[Byte](len2)
+              // read some accumulator updates:
+              val numAccumulatorUpdates = stream.readInt()
+              (1 to numAccumulatorUpdates).foreach { _ =>
+                val updateLen = stream.readInt()
+                val update = new Array[Byte](updateLen)
                 stream.readFully(update)
                 accumulator += Collections.singletonList(update)
-                len2 = stream.readInt()
+
               }
-              new Array[Byte](0)
+              Array.empty[Byte]
           }
         } catch {
           case eof: EOFException => {
@@ -197,62 +179,15 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
   val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
 
-private[spark] object PythonRDD {
-
-  /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
-  def stripPickle(arr: Array[Byte]) : Array[Byte] = {
-    arr.slice(2, arr.length - 1)
-  }
+private object SpecialLengths {
+  val END_OF_DATA_SECTION = -1
+  val PYTHON_EXCEPTION_THROWN = -2
+  val TIMING_DATA = -3
+}
 
-  /**
-   * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
-   * The data format is a 32-bit integer representing the pickled object's length (in bytes),
-   * followed by the pickled data.
-   *
-   * Pickle module:
-   *
-   *    http://docs.python.org/2/library/pickle.html
-   *
-   * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
-   *
-   *    http://hg.python.org/cpython/file/2.6/Lib/pickle.py
-   *    http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
-   *
-   * @param elem the object to write
-   * @param dOut a data output stream
-   */
-  def writeAsPickle(elem: Any, dOut: DataOutputStream) {
-    if (elem.isInstanceOf[Array[Byte]]) {
-      val arr = elem.asInstanceOf[Array[Byte]]
-      dOut.writeInt(arr.length)
-      dOut.write(arr)
-    } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) {
-      val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
-      val length = t._1.length + t._2.length - 3 - 3 + 4  // stripPickle() removes 3 bytes
-      dOut.writeInt(length)
-      dOut.writeByte(Pickle.PROTO)
-      dOut.writeByte(Pickle.TWO)
-      dOut.write(PythonRDD.stripPickle(t._1))
-      dOut.write(PythonRDD.stripPickle(t._2))
-      dOut.writeByte(Pickle.TUPLE2)
-      dOut.writeByte(Pickle.STOP)
-    } else if (elem.isInstanceOf[String]) {
-      // For uniformity, strings are wrapped into Pickles.
-      val s = elem.asInstanceOf[String].getBytes("UTF-8")
-      val length = 2 + 1 + 4 + s.length + 1
-      dOut.writeInt(length)
-      dOut.writeByte(Pickle.PROTO)
-      dOut.writeByte(Pickle.TWO)
-      dOut.write(Pickle.BINUNICODE)
-      dOut.writeInt(Integer.reverseBytes(s.length))
-      dOut.write(s)
-      dOut.writeByte(Pickle.STOP)
-    } else {
-      throw new SparkException("Unexpected RDD type")
-    }
-  }
+private[spark] object PythonRDD {
 
-  def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+  def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
     val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
@@ -270,15 +205,32 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+  def writeToStream(elem: Any, dataOut: DataOutputStream) {
+    elem match {
+      case bytes: Array[Byte] =>
+        dataOut.writeInt(bytes.length)
+        dataOut.write(bytes)
+      case pair: (Array[Byte], Array[Byte]) =>
+        dataOut.writeInt(pair._1.length)
+        dataOut.write(pair._1)
+        dataOut.writeInt(pair._2.length)
+        dataOut.write(pair._2)
+      case str: String =>
+        dataOut.writeUTF(str)
+      case other =>
+        throw new SparkException("Unexpected element type " + other.getClass)
+    }
+  }
+
+  def writeToFile[T](items: java.util.Iterator[T], filename: String) {
     import scala.collection.JavaConverters._
-    writeIteratorToPickleFile(items.asScala, filename)
+    writeToFile(items.asScala, filename)
   }
 
-  def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
+  def writeToFile[T](items: Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
     for (item <- items) {
-      writeAsPickle(item, file)
+      writeToStream(item, file)
     }
     file.close()
   }
@@ -289,17 +241,6 @@ private[spark] object PythonRDD {
   }
 }
 
-private object Pickle {
-  val PROTO: Byte = 0x80.toByte
-  val TWO: Byte = 0x02.toByte
-  val BINUNICODE: Byte = 'X'
-  val STOP: Byte = '.'
-  val TUPLE2: Byte = 0x86.toByte
-  val EMPTY_LIST: Byte = ']'
-  val MARK: Byte = '('
-  val APPENDS: Byte = 'e'
-}
-
 private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
   override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
 }
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 0b4892f98f03976e683476d052d1e33b38f15782..c0ce46e379344ef24d9ee0cfc538cb4fdd6b9062 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -61,50 +61,53 @@ object TaskMetrics {
 
 class ShuffleReadMetrics extends Serializable {
   /**
-   * Time when shuffle finishs
+   * Absolute time when this task finished reading shuffle data
    */
   var shuffleFinishTime: Long = _
 
   /**
-   * Total number of blocks fetched in a shuffle (remote or local)
+   * Number of blocks fetched in this shuffle by this task (remote or local)
    */
   var totalBlocksFetched: Int = _
 
   /**
-   * Number of remote blocks fetched in a shuffle
+   * Number of remote blocks fetched in this shuffle by this task
    */
   var remoteBlocksFetched: Int = _
 
   /**
-   * Local blocks fetched in a shuffle
+   * Number of local blocks fetched in this shuffle by this task
    */
   var localBlocksFetched: Int = _
 
   /**
-   * Total time that is spent blocked waiting for shuffle to fetch data
+   * Time the task spent waiting for remote shuffle blocks. This only includes the time
+   * blocking on shuffle input data. For instance if block B is being fetched while the task is
+   * still not finished processing block A, it is not considered to be blocking on block B.
    */
   var fetchWaitTime: Long = _
 
   /**
-   * The total amount of time for all the shuffle fetches.  This adds up time from overlapping
-   *     shuffles, so can be longer than task time
+   * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all
+   * input blocks. Since block fetches are both pipelined and parallelized, this can
+   * exceed fetchWaitTime and executorRunTime.
    */
   var remoteFetchTime: Long = _
 
   /**
-   * Total number of remote bytes read from a shuffle
+   * Total number of remote bytes read from the shuffle by this task
    */
   var remoteBytesRead: Long = _
 }
 
 class ShuffleWriteMetrics extends Serializable {
   /**
-   * Number of bytes written for a shuffle
+   * Number of bytes written for the shuffle by this task
    */
   var shuffleBytesWritten: Long = _
 
   /**
-   * Time spent blocking on writes to disk or buffer cache, in nanoseconds.
+   * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
    */
   var shuffleWriteTime: Long = _
 }
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index a4bec417529fc9022b9da540993bc235cae67b81..02d75eccc535e8736e80993219d76a1895240ddc 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -24,6 +24,8 @@ import org.apache.spark.partial.SumEvaluator
 import org.apache.spark.util.StatCounter
 import org.apache.spark.{TaskContext, Logging}
 
+import scala.collection.immutable.NumericRange
+
 /**
  * Extra functions available on RDDs of Doubles through an implicit conversion.
  * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
@@ -76,4 +78,128 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
     val evaluator = new SumEvaluator(self.partitions.size, confidence)
     self.context.runApproximateJob(self, processPartition, evaluator, timeout)
   }
+
+  /**
+   * Compute a histogram of the data using bucketCount buckets evenly spaced
+   * between the minimum and maximum of the RDD. For example, if the min value
+   * is 0, the max is 100 and there are two buckets, the resulting buckets will
+   * be [0, 50) and [50, 100]. bucketCount must be at least 1.
+   * If the RDD contains infinity or NaN, an exception is thrown.
+   * If the elements in the RDD do not vary (max == min), a single bucket is always returned.
+   */
+  def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = {
+    // Compute the minimum and the maximum
+    val (max: Double, min: Double) = self.mapPartitions { items =>
+      Iterator(items.foldRight(-1/0.0, Double.NaN)((e: Double, x: Pair[Double, Double]) =>
+        (x._1.max(e), x._2.min(e))))
+    }.reduce { (maxmin1, maxmin2) =>
+      (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2))
+    }
+    if (max.isNaN() || max.isInfinity || min.isInfinity ) {
+      throw new UnsupportedOperationException(
+        "Histogram on either an empty RDD or RDD containing +/-infinity or NaN")
+    }
+    val increment = (max-min)/bucketCount.toDouble
+    val range = if (increment != 0) {
+      Range.Double.inclusive(min, max, increment)
+    } else {
+      List(min, min)
+    }
+    val buckets = range.toArray
+    (buckets, histogram(buckets, true))
+  }
+
+  /**
+   * Compute a histogram using the provided buckets. The buckets are all open
+   * to the left except for the last one, which is closed.
+   * For example, the array
+   * [1, 10, 20, 50] describes the buckets [1, 10), [10, 20) and [20, 50],
+   * i.e. 1 <= x < 10, 10 <= x < 20 and 20 <= x <= 50.
+   * On the input 1 and 50 the resulting histogram would be [1, 0, 1].
+   *
+   * Note: if your buckets are evenly spaced (e.g. [0, 10, 20, 30]), bucket lookup
+   * can be switched from an O(log n) binary search to an O(1) computation per
+   * element (where n = # buckets) by setting evenBuckets to true.
+   * The buckets must be sorted and must not contain any duplicates.
+   * The buckets array must have at least two elements.
+   * All NaN entries are treated the same: if a NaN boundary is used, it must be
+   * the last (maximum) value, and all NaN entries will be counted
+   * in that bucket.
+   */
+  def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = {
+    if (buckets.length < 2) {
+      throw new IllegalArgumentException("buckets array must have at least two elements")
+    }
+    // The histogramPartition function computes the partial histogram for a given
+    // partition. The provided bucketFunction determines which bucket in the array
+    // to increment or returns None if there is no bucket. This is done so we can
+    // specialize for uniformly distributed buckets and save the O(log n) binary
+    // search cost.
+    def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]):
+        Iterator[Array[Long]] = {
+      val counters = new Array[Long](buckets.length - 1)
+      while (iter.hasNext) {
+        bucketFunction(iter.next()) match {
+          case Some(x: Int) => {counters(x) += 1}
+          case _ => {}
+        }
+      }
+      Iterator(counters)
+    }
+    // Merge the counters.
+    def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = {
+      a1.indices.foreach(i => a1(i) += a2(i))
+      a1
+    }
+    // Basic bucket function. This works using Java's built-in array
+    // binary search and takes O(log(buckets.length)) per element.
+    def basicBucketFunction(e: Double): Option[Int] = {
+      val location = java.util.Arrays.binarySearch(buckets, e)
+      if (location < 0) {
+        // If the location is less than 0 then the insertion point in the array
+        // to keep it sorted is -location-1
+        val insertionPoint = -location-1
+        // If we would have to insert before the first element or after the last
+        // one, the value is out of bounds.
+        // We do this rather than buckets.lengthCompare(insertionPoint)
+        // because Array[Double] fails to override it (for now).
+        if (insertionPoint > 0 && insertionPoint < buckets.length) {
+          Some(insertionPoint-1)
+        } else {
+          None
+        }
+      } else if (location < buckets.length - 1) {
+        // Exact match on a boundary other than the last one; count it in that bucket
+        Some(location)
+      } else {
+        // Exact match on the last boundary; count it in the last (closed) bucket
+        Some(location - 1)
+      }
+    }
+    // Determine the bucket function in constant time. Requires that buckets are evenly spaced
+    def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = {
+      // A NaN input can never be assigned to a numeric bucket, so fail fast
+      if (e.isNaN()) {
+        return None
+      }
+      val bucketNumber = (e - min)/(increment)
+      // We do this rather than buckets.lengthCompare(bucketNumber)
+      // because Array[Double] fails to override it (for now).
+      if (bucketNumber > count || bucketNumber < 0) {
+        None
+      } else {
+        Some(bucketNumber.toInt.min(count - 1))
+      }
+    }
+    // Decide which bucket function to pass to histogramPartition. We decide here
+    // rather than having a general function so that the decision need only be made
+    // once rather than once per partition.
+    val bucketFunction = if (evenBuckets) {
+      fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _
+    } else {
+      basicBucketFunction _
+    }
+    self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters)
+  }
+
 }
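
A usage sketch of the two variants added above (local master, object name and sample values are illustrative). The single-argument form derives evenly spaced buckets from the RDD's min and max; the array form counts into caller-supplied buckets, with evenBuckets = true enabling the O(1) lookup path when the boundaries are known to be evenly spaced:

    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._  // implicit conversion to DoubleRDDFunctions

    object HistogramSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "HistogramSketch")
        val rdd = sc.parallelize(Seq(1.0, 10.0, 15.0, 20.0, 50.0))

        // Caller-supplied buckets [1, 10), [10, 20), [20, 50]: counts are (1, 2, 2)
        val counts = rdd.histogram(Array(1.0, 10.0, 20.0, 50.0))

        // Evenly spaced buckets allow the constant-time lookup path: counts are (1, 2, 1, 0, 1)
        val evenCounts = rdd.histogram(Array(0.0, 10.0, 20.0, 30.0, 40.0, 50.0), evenBuckets = true)

        // Generated buckets: boundaries from min to max plus the per-bucket counts
        val (boundaries, generatedCounts) = rdd.histogram(3)

        println(counts.mkString(","))
        println(evenCounts.mkString(","))
        println(boundaries.mkString(",") + " -> " + generatedCounts.mkString(","))
        sc.stop()
      }
    }
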
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index cdb5946b49366da715d71287c8c67c9c02799703..db15baf503ad61e1253e988f1b22dd5084a3b094 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -20,19 +20,16 @@ package org.apache.spark.rdd
 import org.apache.spark.{Partition, TaskContext}
 import scala.reflect.ClassTag
 
-
-private[spark]
-class MapPartitionsRDD[U: ClassTag, T: ClassTag](
+private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     prev: RDD[T],
-    f: Iterator[T] => Iterator[U],
+    f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
     preservesPartitioning: Boolean = false)
   extends RDD[U](prev) {
 
-  override val partitioner =
-    if (preservesPartitioning) firstParent[T].partitioner else None
+  override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
 
   override def getPartitions: Array[Partition] = firstParent[T].partitions
 
   override def compute(split: Partition, context: TaskContext) =
-    f(firstParent[T].iterator(split, context))
+    f(context, split.index, firstParent[T].iterator(split, context))
 }
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
deleted file mode 100644
index 67636751bb8b2401acd9a6db3d028de202a641d7..0000000000000000000000000000000000000000
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import org.apache.spark.{Partition, TaskContext}
-import scala.reflect.ClassTag
-
-
-/**
- * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
- * TaskContext, the closure can either get access to the interruptible flag or get the index
- * of the partition in the RDD.
- */
-private[spark]
-class MapPartitionsWithContextRDD[U: ClassTag, T: ClassTag](
-    prev: RDD[T],
-    f: (TaskContext, Iterator[T]) => Iterator[U],
-    preservesPartitioning: Boolean
-  ) extends RDD[U](prev) {
-
-  override def getPartitions: Array[Partition] = firstParent[T].partitions
-
-  override val partitioner = if (preservesPartitioning) prev.partitioner else None
-
-  override def compute(split: Partition, context: TaskContext) =
-    f(context, firstParent[T].iterator(split, context))
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index da18d45e65deb911e0dfe160b61dc38593de8989..f80d3d601c0eac286b1eda9d752b8d1a04a03cf4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -411,7 +411,6 @@ abstract class RDD[T: ClassTag](
   def pipe(command: String, env: Map[String, String]): RDD[String] =
     new PipedRDD(this, command, env)
 
-
   /**
    * Return an RDD created by piping elements to a forked external process.
    * The print behavior can be customized by providing two functions.
@@ -443,9 +442,10 @@ abstract class RDD[T: ClassTag](
   /**
    * Return a new RDD by applying a function to each partition of this RDD.
    */
-  def mapPartitions[U: ClassTag](f: Iterator[T] => Iterator[U],
-    preservesPartitioning: Boolean = false): RDD[U] = {
-    new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+  def mapPartitions[U: ClassTag](
+      f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
+    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }
 
   /**
@@ -454,8 +454,8 @@ abstract class RDD[T: ClassTag](
    */
   def mapPartitionsWithIndex[U: ClassTag](
       f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
-    val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
-    new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
+    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }
 
   /**
@@ -465,7 +465,8 @@ abstract class RDD[T: ClassTag](
   def mapPartitionsWithContext[U: ClassTag](
       f: (TaskContext, Iterator[T]) => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = {
-    new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
+    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }
 
   /**
@@ -486,11 +487,10 @@ abstract class RDD[T: ClassTag](
   def mapWith[A: ClassTag, U: ClassTag]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => U): RDD[U] = {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex((index, iter) => {
+      val a = constructA(index)
       iter.map(t => f(t, a))
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+    }, preservesPartitioning)
   }
 
   /**
@@ -501,11 +501,10 @@ abstract class RDD[T: ClassTag](
   def flatMapWith[A: ClassTag, U: ClassTag]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => Seq[U]): RDD[U] = {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex((index, iter) => {
+      val a = constructA(index)
       iter.flatMap(t => f(t, a))
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+    }, preservesPartitioning)
   }
 
   /**
@@ -514,11 +513,10 @@ abstract class RDD[T: ClassTag](
    * partition with the index of that partition.
    */
   def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex { (index, iter) =>
+      val a = constructA(index)
       iter.map(t => {f(t, a); t})
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
+    }.foreach(_ => {})
   }
 
   /**
@@ -527,11 +525,10 @@ abstract class RDD[T: ClassTag](
    * partition with the index of that partition.
    */
   def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex((index, iter) => {
+      val a = constructA(index)
       iter.filter(t => p(t, a))
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
+    }, preservesPartitioning = true)
   }
 
   /**
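
All of the *With operators above are now expressed in terms of mapPartitionsWithIndex, which itself is a thin wrapper over MapPartitionsRDD's (TaskContext, partition index, iterator) closure. A small usage sketch (local master, object name and data are illustrative) that tags each element with the partition it came from:

    import org.apache.spark.SparkContext

    object PartitionIndexSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "PartitionIndexSketch")
        // Four partitions, so indices 0 to 3 appear in the output.
        val rdd = sc.parallelize(1 to 8, 4)
        val tagged = rdd.mapPartitionsWithIndex { (index, iter) =>
          iter.map(x => (index, x))
        }
        tagged.collect().foreach(println)
        sc.stop()
      }
    }
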
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 773e9ec182c8b7710dcc01770974335f4378a20a..201572d16ac2dad6d4b5df4906a78b4c94d9a43c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -112,6 +112,9 @@ class DAGScheduler(
   // resubmit failed stages
   val POLL_TIMEOUT = 10L
 
+  // Warns the user if a stage contains a task with size greater than this value (in KB)
+  val TASK_SIZE_TO_WARN = 100
+
   private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor {
     override def preStart() {
       import context.dispatcher
@@ -433,6 +436,18 @@ class DAGScheduler(
         handleExecutorLost(execId)
 
       case BeginEvent(task, taskInfo) =>
+        for (
+          job <- idToActiveJob.get(task.stageId);
+          stage <- stageIdToStage.get(task.stageId);
+          stageInfo <- stageToInfos.get(stage)
+        ) {
+          if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) {
+            stageInfo.emittedTaskSizeWarning = true
+            logWarning(("Stage %d (%s) contains a task of very large " +
+              "size (%d KB). The maximum recommended task size is %d KB.").format(
+              task.stageId, stageInfo.name, taskInfo.serializedSize / 1024, TASK_SIZE_TO_WARN))
+          }
+        }
         listenerBus.post(SparkListenerTaskStart(task, taskInfo))
 
       case GettingResultEvent(task, taskInfo) =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 93599dfdc85fff358d497176bdf92519fbea9426..e9f2198a007e526a237f7190e542761bc191c8af 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -33,4 +33,5 @@ class StageInfo(
   val name = stage.name
   val numPartitions = stage.numPartitions
   val numTasks = stage.numTasks
+  var emittedTaskSizeWarning = false
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 4bae26f3a6a885c73bd1639d61d226cbd06a5ea2..3c22edd5248f403190a2f543597728d08dba92a2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -46,6 +46,8 @@ class TaskInfo(
 
   var failed = false
 
+  var serializedSize: Int = 0
+
   def markGettingResult(time: Long = System.currentTimeMillis) {
     gettingResultTime = time
   }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 4c5eca8537cd62044a27d835bea520691d74931f..8884ea85a34e980796c891a14575f2983216f708 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -377,6 +377,7 @@ private[spark] class ClusterTaskSetManager(
           logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
             taskSet.id, index, serializedTask.limit, timeTaken))
           val taskName = "task %s:%d".format(taskSet.id, index)
+          info.serializedSize = serializedTask.limit
           if (taskAttempts(index).size == 1)
             taskStarted(task,info)
           return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 49d95afdb936460d6ff8f551f92032da6f41d81c..87e009a4de93d8ab71cd2a0fb36393c67ebc9888 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -80,6 +80,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
   protected var _capacity = nextPowerOf2(initialCapacity)
   protected var _mask = _capacity - 1
   protected var _size = 0
+  protected var _growThreshold = (loadFactor * _capacity).toInt
 
   protected var _bitset = new BitSet(_capacity)
 
@@ -116,7 +117,29 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
    * @return The position where the key is placed, plus the highest order bit is set if the key
    *         exists previously.
    */
-  def addWithoutResize(k: T): Int = putInto(_bitset, _data, k)
+  def addWithoutResize(k: T): Int = {
+    var pos = hashcode(hasher.hash(k)) & _mask
+    var i = 1
+    while (true) {
+      if (!_bitset.get(pos)) {
+        // This is a new key.
+        _data(pos) = k
+        _bitset.set(pos)
+        _size += 1
+        return pos | NONEXISTENCE_MASK
+      } else if (_data(pos) == k) {
+        // Found an existing key.
+        return pos
+      } else {
+        val delta = i
+        pos = (pos + delta) & _mask
+        i += 1
+      }
+    }
+    // Never reached here
+    assert(INVALID_POS != INVALID_POS)
+    INVALID_POS
+  }
 
   /**
    * Rehash the set if it is overloaded.
@@ -127,7 +150,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
    *                 to a new position (in the new data array).
    */
   def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
-    if (_size > loadFactor * _capacity) {
+    if (_size > _growThreshold) {
       rehash(k, allocateFunc, moveFunc)
     }
   }
@@ -161,37 +184,6 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
    */
   def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
 
-  /**
-   * Put an entry into the set. Return the position where the key is placed. In addition, the
-   * highest bit in the returned position is set if the key exists prior to this put.
-   *
-   * This function assumes the data array has at least one empty slot.
-   */
-  private def putInto(bitset: BitSet, data: Array[T], k: T): Int = {
-    val mask = data.length - 1
-    var pos = hashcode(hasher.hash(k)) & mask
-    var i = 1
-    while (true) {
-      if (!bitset.get(pos)) {
-        // This is a new key.
-        data(pos) = k
-        bitset.set(pos)
-        _size += 1
-        return pos | NONEXISTENCE_MASK
-      } else if (data(pos) == k) {
-        // Found an existing key.
-        return pos
-      } else {
-        val delta = i
-        pos = (pos + delta) & mask
-        i += 1
-      }
-    }
-    // Never reached here
-    assert(INVALID_POS != INVALID_POS)
-    INVALID_POS
-  }
-
   /**
    * Double the table's size and re-hash everything. We are not really using k, but it is declared
    * so Scala compiler can specialize this method (which leads to calling the specialized version
@@ -205,34 +197,49 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
    */
   private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
     val newCapacity = _capacity * 2
-    require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
-
     allocateFunc(newCapacity)
-    val newData = new Array[T](newCapacity)
     val newBitset = new BitSet(newCapacity)
-    var pos = 0
-    _size = 0
-    while (pos < _capacity) {
-      if (_bitset.get(pos)) {
-        val newPos = putInto(newBitset, newData, _data(pos))
-        moveFunc(pos, newPos & POSITION_MASK)
+    val newData = new Array[T](newCapacity)
+    val newMask = newCapacity - 1
+
+    var oldPos = 0
+    while (oldPos < capacity) {
+      if (_bitset.get(oldPos)) {
+        val key = _data(oldPos)
+        var newPos = hashcode(hasher.hash(key)) & newMask
+        var i = 1
+        var keepGoing = true
+        // No need to check for key equality here since every existing key is distinct,
+        // so this has one less if branch than the similar code path in addWithoutResize.
+        while (keepGoing) {
+          if (!newBitset.get(newPos)) {
+            // Inserting the key at newPos
+            newData(newPos) = key
+            newBitset.set(newPos)
+            moveFunc(oldPos, newPos)
+            keepGoing = false
+          } else {
+            val delta = i
+            newPos = (newPos + delta) & newMask
+            i += 1
+          }
+        }
       }
-      pos += 1
+      oldPos += 1
     }
+
     _bitset = newBitset
     _data = newData
     _capacity = newCapacity
-    _mask = newCapacity - 1
+    _mask = newMask
+    _growThreshold = (loadFactor * newCapacity).toInt
   }
 
   /**
-   * Re-hash a value to deal better with hash functions that don't differ
-   * in the lower bits, similar to java.util.HashMap
+   * Re-hash a value to deal better with hash functions that don't differ in the lower bits.
+   * We use the Murmur Hash 3 finalization step that's also used in fastutil.
    */
-  private def hashcode(h: Int): Int = {
-    val r = h ^ (h >>> 20) ^ (h >>> 12)
-    r ^ (r >>> 7) ^ (r >>> 4)
-  }
+  private def hashcode(h: Int): Int = it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
 
   private def nextPowerOf2(n: Int): Int = {
     val highBit = Integer.highestOneBit(n)
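
Both addWithoutResize and the inlined rehash loop above use the same open-addressing probe: starting from the hashed slot, the i-th step advances by i, so probe offsets grow as triangular numbers modulo the power-of-two capacity, and hashcode() now runs the key hash through fastutil's MurmurHash3 finalizer to spread the low bits. A toy sketch of just the probe sequence (capacity and starting slot are made up):

    object ProbeSequenceSketch {
      def main(args: Array[String]) {
        val capacity = 16          // power of two, as in OpenHashSet
        val mask = capacity - 1
        var pos = 5                // pretend hashcode(hasher.hash(k)) & mask == 5
        var i = 1
        val probes = scala.collection.mutable.ArrayBuffer(pos)
        while (i <= 6) {
          pos = (pos + i) & mask   // same update rule as in the loops above
          probes += pos
          i += 1
        }
        // Prints: 5, 6, 8, 11, 15, 4, 10
        println(probes.mkString(", "))
      }
    }
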
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index fcfc2c9893e965df3667613d91278a083d569a50..f25d921d3f87faa004975500e56c8b96416c3a13 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -63,8 +63,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     testCheckpointing(_.sample(false, 0.5, 0))
     testCheckpointing(_.glom())
     testCheckpointing(_.mapPartitions(_.map(_.toString)))
-    testCheckpointing(r => new MapPartitionsWithContextRDD(r,
-      (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false ))
     testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
     testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
     testCheckpointing(_.pipe(Seq("cat")))
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 352036f182e24c676c9792f83b369d43f0fdb48b..4234f6eac72f4cebd63232e7c2d33c2521805ce1 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -364,6 +364,20 @@ public class JavaAPISuite implements Serializable {
     List<Double> take = rdd.take(5);
   }
 
+  @Test
+  public void javaDoubleRDDHistoGram() {
+    JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
+    // Test using generated buckets
+    Tuple2<double[], long[]> results = rdd.histogram(2);
+    double[] expected_buckets = {1.0, 2.5, 4.0};
+    long[] expected_counts = {2, 2};
+    Assert.assertArrayEquals(expected_buckets, results._1, 0.1);
+    Assert.assertArrayEquals(expected_counts, results._2);
+    // Test with provided buckets
+    long[] histogram = rdd.histogram(expected_buckets);
+    Assert.assertArrayEquals(expected_counts, histogram);
+  }
+
   @Test
   public void map() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..7f50a5a47c2ff3ef5f909469dca837bc399171f8
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.math.abs
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd._
+import org.apache.spark._
+
+class DoubleRDDSuite extends FunSuite with SharedSparkContext {
+  // Tests of the histogram functionality. We test with both evenly
+  // and non-evenly spaced buckets, since the bucket lookup function differs.
+  test("WorksOnEmpty") {
+    // Make sure that it works on an empty input
+    val rdd: RDD[Double] = sc.parallelize(Seq())
+    val buckets = Array(0.0, 10.0)
+    val histogramResults = rdd.histogram(buckets)
+    val histogramResults2 = rdd.histogram(buckets, true)
+    val expectedHistogramResults = Array(0)
+    assert(histogramResults === expectedHistogramResults)
+    assert(histogramResults2 === expectedHistogramResults)
+  }
+
+  test("WorksWithOutOfRangeWithOneBucket") {
+    // Verify that if all of the elements are out of range the counts are zero
+    val rdd = sc.parallelize(Seq(10.01, -0.01))
+    val buckets = Array(0.0, 10.0)
+    val histogramResults = rdd.histogram(buckets)
+    val histogramResults2 = rdd.histogram(buckets, true)
+    val expectedHistogramResults = Array(0)
+    assert(histogramResults === expectedHistogramResults)
+    assert(histogramResults2 === expectedHistogramResults)
+  }
+
+  test("WorksInRangeWithOneBucket") {
+    // Verify the basic case of one bucket and all elements in that bucket works
+    val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+    val buckets = Array(0.0, 10.0)
+    val histogramResults = rdd.histogram(buckets)
+    val histogramResults2 = rdd.histogram(buckets, true)
+    val expectedHistogramResults = Array(4)
+    assert(histogramResults === expectedHistogramResults)
+    assert(histogramResults2 === expectedHistogramResults)
+  }
+
+  test("WorksInRangeWithOneBucketExactMatch") {
+    // Verify the basic case of one bucket and all elements in that bucket works
+    val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+    val buckets = Array(1.0, 4.0)
+    val histogramResults = rdd.histogram(buckets)
+    val histogramResults2 = rdd.histogram(buckets, true)
+    val expectedHistogramResults = Array(4)
+    assert(histogramResults === expectedHistogramResults)
+    assert(histogramResults2 === expectedHistogramResults)
+  }
+
+  test("WorksWithOutOfRangeWithTwoBuckets") {
+    // Verify that out of range works with two buckets
+    val rdd = sc.parallelize(Seq(10.01, -0.01))
+    val buckets = Array(0.0, 5.0, 10.0)
+    val histogramResults = rdd.histogram(buckets)
+    val histogramResults2 = rdd.histogram(buckets, true)
+    val expectedHistogramResults = Array(0, 0)
+    assert(histogramResults === expectedHistogramResults)
+    assert(histogramResults2 === expectedHistogramResults)
+  }
+
+  test("WorksWithOutOfRangeWithTwoUnEvenBuckets") {
+    // Verify that out-of-range elements work with two unevenly spaced buckets
+    val rdd = sc.parallelize(Seq(10.01, -0.01))
+    val buckets = Array(0.0, 4.0, 10.0)
+    val histogramResults = rdd.histogram(buckets)
+    val expectedHistogramResults = Array(0, 0)
+    assert(histogramResults === expectedHistogramResults)
+  }
+
+  test("WorksInRangeWithTwoBuckets") {
+    // Make sure that it works with two equally spaced buckets and elements in each
+    val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+    val buckets = Array(0.0, 5.0, 10.0)
+    val histogramResults = rdd.histogram(buckets)
+    val histogramResults2 = rdd.histogram(buckets, true)
+    val expectedHistogramResults = Array(3, 2)
+    assert(histogramResults === expectedHistogramResults)
+    assert(histogramResults2 === expectedHistogramResults)
+  }
+
+  test("WorksInRangeWithTwoBucketsAndNaN") {
+    // Make sure that NaN elements are not counted with two equally spaced buckets
+    val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN))
+    val buckets = Array(0.0, 5.0, 10.0)
+    val histogramResults = rdd.histogram(buckets)
+    val histogramResults2 = rdd.histogram(buckets, true)
+    val expectedHistogramResults = Array(3, 2)
+    assert(histogramResults === expectedHistogramResults)
+    assert(histogramResults2 === expectedHistogramResults)
+  }
+
+  test("WorksInRangeWithTwoUnevenBuckets") {
+    // Make sure that it works with two unequally spaced buckets and elements in each
+    val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+    val buckets = Array(0.0, 5.0, 11.0)
+    val histogramResults = rdd.histogram(buckets)
+    val expectedHistogramResults = Array(3, 2)
+    assert(histogramResults === expectedHistogramResults)
+  }
+
+  test("WorksMixedRangeWithTwoUnevenBuckets") {
+    // Make sure that it works with two unequally spaced buckets and both in-range and out-of-range elements
+    val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01))
+    val buckets = Array(0.0, 5.0, 11.0)
+    val histogramResults = rdd.histogram(buckets)
+    val expectedHistogramResults = Array(4, 3)
+    assert(histogramResults === expectedHistogramResults)
+  }
+
+  test("WorksMixedRangeWithFourUnevenBuckets") {
+    // Make sure that it works with four unequally spaced buckets and both in-range and out-of-range elements
+    val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+      200.0, 200.1))
+    val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+    val histogramResults = rdd.histogram(buckets)
+    val expectedHistogramResults = Array(4, 2, 1, 3)
+    assert(histogramResults === expectedHistogramResults)
+  }
+
+  test("WorksMixedRangeWithUnevenBucketsAndNaN") {
+    // Make sure that it works with four unequally spaced buckets and a NaN element
+    val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+      200.0, 200.1, Double.NaN))
+    val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+    val histogramResults = rdd.histogram(buckets)
+    val expectedHistogramResults = Array(4, 2, 1, 3)
+    assert(histogramResults === expectedHistogramResults)
+  }
+  // Make sure this works with a NaN end bucket
+  test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") {
+    // Make sure that it works with unequally spaced buckets whose last boundary is NaN
+    val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+      200.0, 200.1, Double.NaN))
+    val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+    val histogramResults = rdd.histogram(buckets)
+    val expectedHistogramResults = Array(4, 2, 1, 2, 3)
+    assert(histogramResults === expectedHistogramResults)
+  }
+  // Make sure this works with a NaN end bucket and infinite values
+  test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") {
+    // Make sure that it works with a NaN end bucket plus +/-infinity and NaN elements
+    val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+      200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN))
+    val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+    val histogramResults = rdd.histogram(buckets)
+    val expectedHistogramResults = Array(4, 2, 1, 2, 4)
+    assert(histogramResults === expectedHistogramResults)
+  }
+
+  test("WorksWithOutOfRangeWithInfiniteBuckets") {
+    // Verify that out of range works with two buckets
+    val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN))
+    val buckets = Array(-1.0/0.0 , 0.0, 1.0/0.0)
+    val histogramResults = rdd.histogram(buckets)
+    val expectedHistogramResults = Array(1, 1)
+    assert(histogramResults === expectedHistogramResults)
+  }
+  // Test the failure mode with an invalid bucket array
+  test("ThrowsExceptionOnInvalidBucketArray") {
+    val rdd = sc.parallelize(Seq(1.0))
+    // Empty array
+    intercept[IllegalArgumentException] {
+      val buckets = Array.empty[Double]
+      val result = rdd.histogram(buckets)
+    }
+    // Single element array
+    intercept[IllegalArgumentException] {
+      val buckets = Array(1.0)
+      val result = rdd.histogram(buckets)
+    }
+  }
+
+  // Test automatic histogram function
+  test("WorksWithoutBucketsBasic") {
+    // Verify the basic case of one bucket and all elements in that bucket works
+    val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+    val (histogramBuckets, histogramResults) = rdd.histogram(1)
+    val expectedHistogramResults = Array(4)
+    val expectedHistogramBuckets = Array(1.0, 4.0)
+    assert(histogramResults === expectedHistogramResults)
+    assert(histogramBuckets === expectedHistogramBuckets)
+  }
+  // Test automatic histogram function with a single element
+  test("WorksWithoutBucketsBasicSingleElement") {
+    // Verify the basic case of one bucket and all elements in that bucket works
+    val rdd = sc.parallelize(Seq(1))
+    val (histogramBuckets, histogramResults) = rdd.histogram(1)
+    val expectedHistogramResults = Array(1)
+    val expectedHistogramBuckets = Array(1.0, 1.0)
+    assert(histogramResults === expectedHistogramResults)
+    assert(histogramBuckets === expectedHistogramBuckets)
+  }
+  // Test automatic histogram function when all elements have the same value
+  test("WorksWithoutBucketsBasicNoRange") {
+    // Verify that identical elements collapse into a single degenerate bucket
+    val rdd = sc.parallelize(Seq(1, 1, 1, 1))
+    val (histogramBuckets, histogramResults) = rdd.histogram(1)
+    val expectedHistogramResults = Array(4)
+    val expectedHistogramBuckets = Array(1.0, 1.0)
+    assert(histogramResults === expectedHistogramResults)
+    assert(histogramBuckets === expectedHistogramBuckets)
+  }
+
+  test("WorksWithoutBucketsBasicTwo") {
+    // Verify that two evenly spaced generated buckets split the elements correctly
+    val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+    val (histogramBuckets, histogramResults) = rdd.histogram(2)
+    val expectedHistogramResults = Array(2, 2)
+    val expectedHistogramBuckets = Array(1.0, 2.5, 4.0)
+    assert(histogramResults === expectedHistogramResults)
+    assert(histogramBuckets === expectedHistogramBuckets)
+  }
+
+  test("WorksWithoutBucketsWithMoreRequestedThanElements") {
+    // Verify that requesting more buckets than there are elements works
+    val rdd = sc.parallelize(Seq(1, 2))
+    val (histogramBuckets, histogramResults) = rdd.histogram(10)
+    val expectedHistogramResults =
+      Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1)
+    val expectedHistogramBuckets =
+      Array(1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0)
+    assert(histogramResults === expectedHistogramResults)
+    assert(histogramBuckets === expectedHistogramBuckets)
+  }
+
+  // Test the failure mode with an invalid RDD
+  test("ThrowsExceptionOnInvalidRDDs") {
+    // infinity
+    intercept[UnsupportedOperationException] {
+      val rdd = sc.parallelize(Seq(1, 1.0/0.0))
+      val result = rdd.histogram(1)
+    }
+    // NaN
+    intercept[UnsupportedOperationException] {
+      val rdd = sc.parallelize(Seq(1, Double.NaN))
+      val result = rdd.histogram(1)
+    }
+    // Empty
+    intercept[UnsupportedOperationException] {
+      val rdd: RDD[Double] = sc.parallelize(Seq())
+      val result = rdd.histogram(1)
+    }
+  }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
index 984881861c9a985a3dc92950c0f7759a37fa949a..002368ff554f7226cc4f7f6b1643f8c1f61df0c8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.rdd.RDD
 
 
 class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+  val WAIT_TIMEOUT_MILLIS = 10000
 
   test("inner method") {
     sc = new SparkContext("local", "joblogger")
@@ -92,6 +93,8 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
     val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
     rdd.reduceByKey(_+_).collect()
 
+    assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
     val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER)
     
     joblogger.getLogDir should be ("/tmp/spark-%s".format(user))
@@ -120,7 +123,9 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
     sc.addSparkListener(joblogger)
     val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
     rdd.reduceByKey(_+_).collect()
-    
+
+    assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
     joblogger.onJobStartCount should be (1)
     joblogger.onJobEndCount should be (1)
     joblogger.onTaskEndCount should be (8)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
index ca3f684668d605e868d491fffb5b4f0bcc4a23a1..63e874fed3942965d9c7659b2cf86913120e124f 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -2,8 +2,20 @@ package org.apache.spark.util.collection
 
 import scala.collection.mutable.HashSet
 import org.scalatest.FunSuite
-
-class OpenHashMapSuite extends FunSuite {
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
+
+class OpenHashMapSuite extends FunSuite with ShouldMatchers {
+
+  test("size for specialized, primitive value (int)") {
+    val capacity = 1024
+    val map = new OpenHashMap[String, Int](capacity)
+    val actualSize = SizeEstimator.estimate(map)
+    // 64 bits for the pointers, 32 bits for the ints, and 1 bit for the bitset.
+    val expectedSize = capacity * (64 + 32 + 1) / 8
+    // Make sure we are not allocating a significant amount of memory beyond the expected size.
+    actualSize should be <= (expectedSize * 1.1).toLong
+  }
 
   test("initialization") {
     val goodMap1 = new OpenHashMap[String, Int](1)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
index 4e11e8a628b44e3dffa1b076263cfc3696eea438..4768a1e60bc31d4488923547a23c45ea57a2caae 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -1,9 +1,27 @@
 package org.apache.spark.util.collection
 
 import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
 
+import org.apache.spark.util.SizeEstimator
 
-class OpenHashSetSuite extends FunSuite {
+
+class OpenHashSetSuite extends FunSuite with ShouldMatchers {
+
+  test("size for specialized, primitive int") {
+    val loadFactor = 0.7
+    val set = new OpenHashSet[Int](64, loadFactor)
+    for (i <- 0 until 1024) {
+      set.add(i)
+    }
+    assert(set.size === 1024)
+    assert(set.capacity > 1024)
+    val actualSize = SizeEstimator.estimate(set)
+    // 32 bits for the ints + 1 bit for the bitset
+    val expectedSize = set.capacity * (32 + 1) / 8
+    // Make sure we are not allocating a significant amount of memory beyond the expected size.
+    actualSize should be <= (expectedSize * 1.1).toLong
+  }
 
   test("primitive int") {
     val set = new OpenHashSet[Int]
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
similarity index 79%
rename from core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala
rename to core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
index dfd6aed2c4bccf7f1d9a25690ce0c6be41097678..2220b4f0d5efe1bf63ad1fecc5232fb704fbd09a 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
@@ -2,8 +2,20 @@ package org.apache.spark.util.collection
 
 import scala.collection.mutable.HashSet
 import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
 
-class PrimitiveKeyOpenHashSetSuite extends FunSuite {
+class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers {
+
+  test("size for specialized, primitive key, value (int, int)") {
+    val capacity = 1024
+    val map = new PrimitiveKeyOpenHashMap[Int, Int](capacity)
+    val actualSize = SizeEstimator.estimate(map)
+    // 32 bits for the keys, 32 bits for the values, and 1 bit for the bitset.
+    val expectedSize = capacity * (32 + 32 + 1) / 8
+    // Make sure we are not allocating a significant amount of memory beyond the expected size.
+    actualSize should be <= (expectedSize * 1.1).toLong
+  }
 
   test("initialization") {
     val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1)
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 4056e9c15db2b625f21523576cc89af3faa12ced..68fd6c2ab1db249045307f9cb6b7effb0c57826e 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -45,6 +45,10 @@ System Properties:
 Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the hadoop cluster.
 This would be used to connect to the cluster, write to the dfs and submit jobs to the resource manager.
 
+There are two scheduler modes that can be used to launch a Spark application on YARN.
+
+## Launching a Spark application with the YARN Client in yarn-standalone mode
+
 The command to launch the YARN Client is as follows:
 
     SPARK_JAR=<SPARK_ASSEMBLY_JAR_FILE> ./spark-class org.apache.spark.deploy.yarn.Client \
@@ -52,6 +56,7 @@ The command to launch the YARN Client is as follows:
       --class <APP_MAIN_CLASS> \
       --args <APP_MAIN_ARGUMENTS> \
       --num-workers <NUMBER_OF_WORKER_MACHINES> \
+      --master-class <ApplicationMaster_CLASS> \
       --master-memory <MEMORY_FOR_MASTER> \
       --worker-memory <MEMORY_PER_WORKER> \
       --worker-cores <CORES_PER_WORKER> \
@@ -85,11 +90,29 @@ For example:
     $ cat $YARN_APP_LOGS_DIR/$YARN_APP_ID/container*_000001/stdout
     Pi is roughly 3.13794
 
-The above starts a YARN Client programs which periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running.
+The above starts a YARN client program which launches the default Application Master. SparkPi is then run as a child thread of the Application Master. The YARN client periodically polls the Application Master for status updates and displays them in the console, and it exits once your application has finished running.
+
+In this mode your application actually runs on the remote machine hosting the Application Master, so applications that involve local interaction, such as spark-shell, will not work well.
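+
+Note that in this mode your application must still use "yarn-standalone" as the master url when it creates its SparkContext; a common approach is to pass it in through the Client's --args option, as in the example above. The following minimal sketch is for illustration only (the MyYarnApp name and the counting job are placeholders):
+
+    import org.apache.spark.SparkContext
+
+    object MyYarnApp {
+      def main(args: Array[String]) {
+        // args(0) is expected to be "yarn-standalone", supplied via the Client's --args option
+        val sc = new SparkContext(args(0), "MyYarnApp")
+        println(sc.parallelize(1 to 1000).count())
+        sc.stop()
+      }
+    }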
+
+## Launching a Spark application in yarn-client mode
+
+In yarn-client mode the application is launched locally, just like running an application or spark-shell in Local / Mesos / Standalone mode. The launch method is also similar; just make sure that whenever you need to specify a master url, you use "yarn-client" instead. You also need to export the environment variables SPARK_JAR and SPARK_YARN_APP_JAR.
+
+To tune the worker cores, number of workers, and memory, export SPARK_WORKER_CORES, SPARK_WORKER_MEMORY and SPARK_WORKER_INSTANCES, e.g. in ./conf/spark-env.sh.
+
+For example:
+
+    SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \
+    SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
+    ./run-example org.apache.spark.examples.SparkPi yarn-client
+
+
+    SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \
+    SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
+    MASTER=yarn-client ./spark-shell
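+
+The "yarn-client" master url can also be used directly when constructing a SparkContext. A minimal, illustrative sketch (the YarnClientExample name and the counting job are placeholders; SPARK_JAR and SPARK_YARN_APP_JAR must still be exported as above):
+
+    import org.apache.spark.SparkContext
+
+    object YarnClientExample {
+      def main(args: Array[String]) {
+        // "yarn-client" launches the driver locally and the workers on YARN
+        val sc = new SparkContext("yarn-client", "YarnClientExample")
+        println(sc.parallelize(1 to 1000).count())
+        sc.stop()
+      }
+    }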
 
 # Important Notes
 
-- When your application instantiates a Spark context it must use a special "yarn-standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "yarn-standalone" as an argument to your program, as shown in the example above.
 - We do not request container resources based on the number of cores. Thus the number of cores given via command line arguments cannot be guaranteed.
 - The local directories used for spark will be the local directories configured for YARN (Hadoop Yarn config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored.
 - The --files and --archives options support specifying file names with the # similar to Hadoop. For example you can specify: --files localtest.txt#appSees.txt and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name appSees.txt and your application should use the name as appSees.txt to reference it when running on YARN.
diff --git a/docs/tuning.md b/docs/tuning.md
index f33fda37ebaba4d6b34d775bdf0190c987a3da83..a4be18816928e7fe08c24b76186af63131ed03cd 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -39,7 +39,8 @@ in your operations) and performance. It provides two serialization libraries:
   for best performance.
 
 You can switch to using Kryo by calling `System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")`
-*before* creating your SparkContext. The only reason it is not the default is because of the custom
+*before* creating your SparkContext. This setting configures the serializer used not only for shuffling data between worker
+nodes but also for serializing RDDs to disk. The only reason Kryo is not the default is because of the custom
 registration requirement, but we recommend trying it in any network-intensive application.
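+
+For example, a driver might enable Kryo as follows (a minimal sketch; the MyApp name and the sample job are placeholders):
+
+    import org.apache.spark.SparkContext
+
+    object MyApp {
+      def main(args: Array[String]) {
+        // Must be set before the SparkContext is created
+        System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+        val sc = new SparkContext("local", "MyApp")
+        // Shuffles (and serialized RDD storage) now go through Kryo
+        sc.parallelize(1 to 1000).distinct().count()
+        sc.stop()
+      }
+    }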
 
 Finally, to register your classes with Kryo, create a public class that extends
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 1d0d002d36623409f341688811d4d846ec6ee9fe..0b42e729f8dcc756c711584de2b2a4f071b480c5 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -32,6 +32,6 @@ target: docs/
 
 private: no
 
-exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
+exclude: pyspark.cloudpickle pyspark.worker pyspark.join
          pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
          pyspark.rddsampler pyspark.daemon
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index da3d96689aa15dc14707e8f88c9170e51af3cede..2204e9c9ca7011f10ca60c900896db13868c35ff 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -90,9 +90,11 @@ import struct
 import SocketServer
 import threading
 from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import read_int, read_with_length, load_pickle
+from pyspark.serializers import read_int, PickleSerializer
 
 
+pickleSer = PickleSerializer()
+
 # Holds accumulators registered on the current machine, keyed by ID. This is then used to send
 # the local accumulator updates back to the driver program at the end of a task.
 _accumulatorRegistry = {}
@@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
         from pyspark.accumulators import _accumulatorRegistry
         num_updates = read_int(self.rfile)
         for _ in range(num_updates):
-            (aid, update) = load_pickle(read_with_length(self.rfile))
+            (aid, update) = pickleSer._read_with_length(self.rfile)
             _accumulatorRegistry[aid] += update
         # Write a byte in acknowledgement
         self.wfile.write(struct.pack("!b", 1))
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a7ca8bc888c6759aff5784d26ad7df015d2fe2f4..cbd41e58c4a780392b2a6b8c58320535e416cd36 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 
@@ -42,7 +42,7 @@ class SparkContext(object):
 
     _gateway = None
     _jvm = None
-    _writeIteratorToPickleFile = None
+    _writeToFile = None
     _takePartition = None
     _next_accum_id = 0
     _active_spark_context = None
@@ -51,7 +51,7 @@ class SparkContext(object):
 
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
-        environment=None, batchSize=1024):
+        environment=None, batchSize=1024, serializer=PickleSerializer()):
         """
         Create a new SparkContext.
 
@@ -67,6 +67,7 @@ class SparkContext(object):
         @param batchSize: The number of Python objects represented as a single
                Java object.  Set 1 to disable batching or -1 to use an
                unlimited batch size.
+        @param serializer: The serializer for RDDs.
 
 
         >>> from pyspark.context import SparkContext
@@ -83,7 +84,13 @@ class SparkContext(object):
         self.jobName = jobName
         self.sparkHome = sparkHome or None # None becomes null in Py4J
         self.environment = environment or {}
-        self.batchSize = batchSize  # -1 represents a unlimited batch size
+        self._batchSize = batchSize  # -1 represents an unlimited batch size
+        self._unbatched_serializer = serializer
+        if batchSize == 1:
+            self.serializer = self._unbatched_serializer
+        else:
+            self.serializer = BatchedSerializer(self._unbatched_serializer,
+                                                batchSize)
 
         # Create the Java SparkContext through Py4J
         empty_string_array = self._gateway.new_array(self._jvm.String, 0)
@@ -125,8 +132,8 @@ class SparkContext(object):
             if not SparkContext._gateway:
                 SparkContext._gateway = launch_gateway()
                 SparkContext._jvm = SparkContext._gateway.jvm
-                SparkContext._writeIteratorToPickleFile = \
-                    SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+                SparkContext._writeToFile = \
+                    SparkContext._jvm.PythonRDD.writeToFile
                 SparkContext._takePartition = \
                     SparkContext._jvm.PythonRDD.takePartition
 
@@ -184,15 +191,17 @@ class SparkContext(object):
         # Make sure we distribute data evenly if it's smaller than self.batchSize
         if "__len__" not in dir(c):
             c = list(c)    # Make it a list so we can compute its length
-        batchSize = min(len(c) // numSlices, self.batchSize)
+        batchSize = min(len(c) // numSlices, self._batchSize)
         if batchSize > 1:
-            c = batched(c, batchSize)
-        for x in c:
-            write_with_length(dump_pickle(x), tempFile)
+            serializer = BatchedSerializer(self._unbatched_serializer,
+                                           batchSize)
+        else:
+            serializer = self._unbatched_serializer
+        serializer.dump_stream(c, tempFile)
         tempFile.close()
-        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
-        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
-        return RDD(jrdd, self)
+        readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+        jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+        return RDD(jrdd, self, serializer)
 
     def textFile(self, name, minSplits=None):
         """
@@ -201,21 +210,39 @@ class SparkContext(object):
         RDD of Strings.
         """
         minSplits = minSplits or min(self.defaultParallelism, 2)
-        jrdd = self._jsc.textFile(name, minSplits)
-        return RDD(jrdd, self)
+        return RDD(self._jsc.textFile(name, minSplits), self,
+                   MUTF8Deserializer())
 
-    def _checkpointFile(self, name):
+    def _checkpointFile(self, name, input_deserializer):
         jrdd = self._jsc.checkpointFile(name)
-        return RDD(jrdd, self)
+        return RDD(jrdd, self, input_deserializer)
 
     def union(self, rdds):
         """
         Build the union of a list of RDDs.
+
+        This supports unions() of RDDs with different serialized formats,
+        although this forces them to be reserialized using the default
+        serializer:
+
+        >>> path = os.path.join(tempdir, "union-text.txt")
+        >>> with open(path, "w") as testFile:
+        ...    testFile.write("Hello")
+        >>> textFile = sc.textFile(path)
+        >>> textFile.collect()
+        [u'Hello']
+        >>> parallelized = sc.parallelize(["World!"])
+        >>> sorted(sc.union([textFile, parallelized]).collect())
+        [u'Hello', 'World!']
         """
+        first_jrdd_deserializer = rdds[0]._jrdd_deserializer
+        if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
+            rdds = [x._reserialize() for x in rdds]
         first = rdds[0]._jrdd
         rest = [x._jrdd for x in rdds[1:]]
-        rest = ListConverter().convert(rest, self.gateway._gateway_client)
-        return RDD(self._jsc.union(first, rest), self)
+        rest = ListConverter().convert(rest, self._gateway._gateway_client)
+        return RDD(self._jsc.union(first, rest), self,
+                   rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):
         """
@@ -223,7 +250,9 @@ class SparkContext(object):
         object for reading it in distributed functions. The variable will be
         sent to each cluster only once.
         """
-        jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+        pickleSer = PickleSerializer()
+        pickled = pickleSer.dumps(value)
+        jbroadcast = self._jsc.broadcast(bytearray(pickled))
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
 
@@ -235,7 +264,7 @@ class SparkContext(object):
         and floating-point numbers if you do not provide one. For other types,
         a custom AccumulatorParam can be used.
         """
-        if accum_param == None:
+        if accum_param is None:
             if isinstance(value, int):
                 accum_param = accumulators.INT_ACCUMULATOR_PARAM
             elif isinstance(value, float):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 245a132dfdcecfc88b3d7a9e85b6cc09409d1671..d2cb5f191aab1240e1acd7ead29083f976a969e3 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,7 +18,7 @@
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
-from itertools import chain, ifilter, imap, product
+from itertools import chain, ifilter, imap
 import operator
 import os
 import sys
@@ -27,9 +27,8 @@ from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
 
-from pyspark import cloudpickle
-from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
-    read_from_pickle_file, pack_long
+from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
+    BatchedSerializer, CloudPickleSerializer, pack_long
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -48,12 +47,12 @@ class RDD(object):
     operated on in parallel.
     """
 
-    def __init__(self, jrdd, ctx):
+    def __init__(self, jrdd, ctx, jrdd_deserializer):
         self._jrdd = jrdd
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = ctx
-        self._partitionFunc = None
+        self._jrdd_deserializer = jrdd_deserializer
 
     @property
     def context(self):
@@ -247,7 +246,23 @@ class RDD(object):
         >>> rdd.union(rdd).collect()
         [1, 1, 2, 3, 1, 1, 2, 3]
         """
-        return RDD(self._jrdd.union(other._jrdd), self.ctx)
+        if self._jrdd_deserializer == other._jrdd_deserializer:
+            rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
+                      self._jrdd_deserializer)
+            return rdd
+        else:
+            # These RDDs contain data in different serialized formats, so we
+            # must normalize them to the default serializer.
+            self_copy = self._reserialize()
+            other_copy = other._reserialize()
+            return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+                       self.ctx.serializer)
+
+    def _reserialize(self):
+        if self._jrdd_deserializer == self.ctx.serializer:
+            return self
+        else:
+            return self.map(lambda x: x, preservesPartitioning=True)
 
     def __add__(self, other):
         """
@@ -334,17 +349,9 @@ class RDD(object):
         [(1, 1), (1, 2), (2, 1), (2, 2)]
         """
         # Due to batching, we can't use the Java cartesian method.
-        java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
-        def unpack_batches(pair):
-            (x, y) = pair
-            if type(x) == Batch or type(y) == Batch:
-                xs = x.items if type(x) == Batch else [x]
-                ys = y.items if type(y) == Batch else [y]
-                for pair in product(xs, ys):
-                    yield pair
-            else:
-                yield pair
-        return java_cartesian.flatMap(unpack_batches)
+        deserializer = CartesianDeserializer(self._jrdd_deserializer,
+                                             other._jrdd_deserializer)
+        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
 
     def groupBy(self, f, numPartitions=None):
         """
@@ -391,8 +398,8 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        picklesInJava = self._jrdd.collect().iterator()
-        return list(self._collect_iterator_through_file(picklesInJava))
+        bytesInJava = self._jrdd.collect().iterator()
+        return list(self._collect_iterator_through_file(bytesInJava))
 
     def _collect_iterator_through_file(self, iterator):
         # Transferring lots of data through Py4J can be slow because
@@ -400,10 +407,10 @@ class RDD(object):
         # file and read it back.
         tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
         tempFile.close()
-        self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+        self.ctx._writeToFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
-            for item in read_from_pickle_file(tempFile):
+            for item in self._jrdd_deserializer.load_stream(tempFile):
                 yield item
         os.unlink(tempFile.name)
 
@@ -571,7 +578,7 @@ class RDD(object):
         items = []
         for partition in range(mapped._jrdd.splits().size()):
             iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
-            items.extend(self._collect_iterator_through_file(iterator))
+            items.extend(mapped._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
         return items[:num]
@@ -735,6 +742,7 @@ class RDD(object):
         # Transferring O(n) objects to Java is too expensive.  Instead, we'll
         # form the hash buckets in Python, transferring O(numPartitions) objects
         # to Java.  Each object is a (splitNumber, [objects]) pair.
+        outputSerializer = self.ctx._unbatched_serializer
         def add_shuffle_key(split, iterator):
 
             buckets = defaultdict(list)
@@ -743,14 +751,14 @@ class RDD(object):
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
             for (split, items) in buckets.iteritems():
                 yield pack_long(split)
-                yield dump_pickle(Batch(items))
+                yield outputSerializer.dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
         partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                      id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
-        rdd = RDD(jrdd, self.ctx)
+        rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if
         # partitionFunc is a lambda:
         rdd._partitionFunc = partitionFunc
@@ -787,7 +795,8 @@ class RDD(object):
             numPartitions = self.ctx.defaultParallelism
         def combineLocally(iterator):
             combiners = {}
-            for (k, v) in iterator:
+            for x in iterator:
+                (k, v) = x
                 if k not in combiners:
                     combiners[k] = createCombiner(v)
                 else:
@@ -929,38 +938,39 @@ class PipelinedRDD(RDD):
     20
     """
     def __init__(self, prev, func, preservesPartitioning=False):
-        if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
+        if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
+            # This transformation is the first in its stage:
+            self.func = func
+            self.preservesPartitioning = preservesPartitioning
+            self._prev_jrdd = prev._jrdd
+            self._prev_jrdd_deserializer = prev._jrdd_deserializer
+        else:
             prev_func = prev.func
             def pipeline_func(split, iterator):
                 return func(split, prev_func(split, iterator))
             self.func = pipeline_func
             self.preservesPartitioning = \
                 prev.preservesPartitioning and preservesPartitioning
-            self._prev_jrdd = prev._prev_jrdd
-        else:
-            self.func = func
-            self.preservesPartitioning = preservesPartitioning
-            self._prev_jrdd = prev._jrdd
+            self._prev_jrdd = prev._prev_jrdd  # maintain the pipeline
+            self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = prev.ctx
         self.prev = prev
         self._jrdd_val = None
+        self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
 
     @property
     def _jrdd(self):
         if self._jrdd_val:
             return self._jrdd_val
-        func = self.func
-        if not self._bypass_serializer and self.ctx.batchSize != 1:
-            oldfunc = self.func
-            batchSize = self.ctx.batchSize
-            def batched_func(split, iterator):
-                return batched(oldfunc(split, iterator), batchSize)
-            func = batched_func
-        cmds = [func, self._bypass_serializer]
-        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+        if self._bypass_serializer:
+            serializer = NoOpSerializer()
+        else:
+            serializer = self.ctx.serializer
+        command = (self.func, self._prev_jrdd_deserializer, serializer)
+        pickled_command = CloudPickleSerializer().dumps(command)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
@@ -971,8 +981,9 @@ class PipelinedRDD(RDD):
         includes = ListConverter().convert(self.ctx._python_includes,
                                      self.ctx._gateway._gateway_client)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-            pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
-            broadcast_vars, self.ctx._javaAccumulator, class_tag)
+            bytearray(pickled_command), env, includes, self.preservesPartitioning,
+            self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator,
+            class_tag)
         self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
 
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 54fed1c9c70f66e503abb5c523d6327bb9bae8b4..811fa6f018b23f3c9883bd2a770f03c044786850 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -15,45 +15,269 @@
 # limitations under the License.
 #
 
-import struct
+"""
+PySpark supports custom serializers for transferring data; this can improve
+performance.
+
+By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
+C{cPickle} serializer, which can serialize nearly any Python object.
+Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
+faster.
+
+The serializer is chosen when creating L{SparkContext}:
+
+>>> from pyspark.context import SparkContext
+>>> from pyspark.serializers import MarshalSerializer
+>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
+>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+>>> sc.stop()
+
+By default, PySpark serializes objects in batches; the batch size can be
+controlled through SparkContext's C{batchSize} parameter
+(the default size is 1024 objects):
+
+>>> sc = SparkContext('local', 'test', batchSize=2)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+
+Behind the scenes, this creates a JavaRDD with four partitions, each of
+which contains two batches of two objects:
+
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+8L
+>>> sc.stop()
+
+A batch size of -1 uses an unlimited batch size, and a size of 1 disables
+batching:
+
+>>> sc = SparkContext('local', 'test', batchSize=1)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+16L
+"""
+
 import cPickle
+from itertools import chain, izip, product
+import marshal
+import struct
+from pyspark import cloudpickle
+
+
+__all__ = ["PickleSerializer", "MarshalSerializer"]
+
+
+class SpecialLengths(object):
+    END_OF_DATA_SECTION = -1
+    PYTHON_EXCEPTION_THROWN = -2
+    TIMING_DATA = -3
+
+
+class Serializer(object):
+
+    def dump_stream(self, iterator, stream):
+        """
+        Serialize an iterator of objects to the output stream.
+        """
+        raise NotImplementedError
+
+    def load_stream(self, stream):
+        """
+        Return an iterator of deserialized objects from the input stream.
+        """
+        raise NotImplementedError
+
+
+    def _load_stream_without_unbatching(self, stream):
+        return self.load_stream(stream)
+
+    # Note: our notion of "equality" is that output generated by
+    # equal serializers can be deserialized using the same serializer.
+
+    # This default implementation handles the simple cases;
+    # subclasses should override __eq__ as appropriate.
+
+    def __eq__(self, other):
+        return isinstance(other, self.__class__)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class FramedSerializer(Serializer):
+    """
+    Serializer that writes objects as a stream of (length, data) pairs,
+    where C{length} is a 32-bit integer and data is C{length} bytes.
+    """
+
+    def dump_stream(self, iterator, stream):
+        for obj in iterator:
+            self._write_with_length(obj, stream)
+
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self._read_with_length(stream)
+            except EOFError:
+                return
+
+    def _write_with_length(self, obj, stream):
+        serialized = self.dumps(obj)
+        write_int(len(serialized), stream)
+        stream.write(serialized)
+
+    def _read_with_length(self, stream):
+        length = read_int(stream)
+        obj = stream.read(length)
+        if obj == "":
+            raise EOFError
+        return self.loads(obj)
+
+    def dumps(self, obj):
+        """
+        Serialize an object into a byte array.
+        When batching is used, this will be called with an array of objects.
+        """
+        raise NotImplementedError
+
+    def loads(self, obj):
+        """
+        Deserialize an object from a byte array.
+        """
+        raise NotImplementedError
+
+
+class BatchedSerializer(Serializer):
+    """
+    Serializes a stream of objects in batches by calling its wrapped
+    Serializer with streams of objects.
+    """
+
+    UNLIMITED_BATCH_SIZE = -1
+
+    def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
+        self.serializer = serializer
+        self.batchSize = batchSize
+
+    def _batched(self, iterator):
+        if self.batchSize == self.UNLIMITED_BATCH_SIZE:
+            yield list(iterator)
+        else:
+            items = []
+            count = 0
+            for item in iterator:
+                items.append(item)
+                count += 1
+                if count == self.batchSize:
+                    yield items
+                    items = []
+                    count = 0
+            if items:
+                yield items
+
+    def dump_stream(self, iterator, stream):
+        self.serializer.dump_stream(self._batched(iterator), stream)
+
+    def load_stream(self, stream):
+        return chain.from_iterable(self._load_stream_without_unbatching(stream))
+
+    def _load_stream_without_unbatching(self, stream):
+        return self.serializer.load_stream(stream)
+
+    def __eq__(self, other):
+        return isinstance(other, BatchedSerializer) and \
+               other.serializer == self.serializer
+
+    def __str__(self):
+        return "BatchedSerializer<%s>" % str(self.serializer)
 
 
-class Batch(object):
+class CartesianDeserializer(FramedSerializer):
     """
-    Used to store multiple RDD entries as a single Java object.
+    Deserializes the JavaRDD cartesian() of two PythonRDDs.
+    """
+
+    def __init__(self, key_ser, val_ser):
+        self.key_ser = key_ser
+        self.val_ser = val_ser
+
+    def load_stream(self, stream):
+        key_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_stream = self.val_ser._load_stream_without_unbatching(stream)
+        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
+        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
+        for (keys, vals) in izip(key_stream, val_stream):
+            keys = keys if key_is_batched else [keys]
+            vals = vals if val_is_batched else [vals]
+            for pair in product(keys, vals):
+                yield pair
+
+    def __eq__(self, other):
+        return isinstance(other, CartesianDeserializer) and \
+               self.key_ser == other.key_ser and self.val_ser == other.val_ser
+
+    def __str__(self):
+        return "CartesianDeserializer<%s, %s>" % \
+               (str(self.key_ser), str(self.val_ser))
 
-    This relieves us from having to explicitly track whether an RDD
-    is stored as batches of objects and avoids problems when processing
-    the union() of batched and unbatched RDDs (e.g. the union() of textFile()
-    with another RDD).
+
+class NoOpSerializer(FramedSerializer):
+
+    def loads(self, obj): return obj
+    def dumps(self, obj): return obj
+
+
+class PickleSerializer(FramedSerializer):
     """
-    def __init__(self, items):
-        self.items = items
+    Serializes objects using Python's cPickle serializer:
 
+        http://docs.python.org/2/library/pickle.html
 
-def batched(iterator, batchSize):
-    if batchSize == -1: # unlimited batch size
-        yield Batch(list(iterator))
-    else:
-        items = []
-        count = 0
-        for item in iterator:
-            items.append(item)
-            count += 1
-            if count == batchSize:
-                yield Batch(items)
-                items = []
-                count = 0
-        if items:
-            yield Batch(items)
+    This serializer supports nearly any Python object, but may
+    not be as fast as more specialized serializers.
+    """
 
+    def dumps(self, obj): return cPickle.dumps(obj, 2)
+    loads = cPickle.loads
 
-def dump_pickle(obj):
-    return cPickle.dumps(obj, 2)
+class CloudPickleSerializer(PickleSerializer):
 
+    def dumps(self, obj): return cloudpickle.dumps(obj, 2)
 
-load_pickle = cPickle.loads
+
+class MarshalSerializer(FramedSerializer):
+    """
+    Serializes objects using Python's Marshal serializer:
+
+        http://docs.python.org/2/library/marshal.html
+
+    This serializer is faster than PickleSerializer but supports fewer datatypes.
+    """
+
+    dumps = marshal.dumps
+    loads = marshal.loads
+
+
+class MUTF8Deserializer(Serializer):
+    """
+    Deserializes streams written by Java's DataOutputStream.writeUTF().
+    """
+
+    def loads(self, stream):
+        length = struct.unpack('>H', stream.read(2))[0]
+        return stream.read(length).decode('utf8')
+
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self.loads(stream)
+            except struct.error:
+                return
+            except EOFError:
+                return
 
 
 def read_long(stream):
@@ -84,25 +308,4 @@ def write_int(value, stream):
 
 def write_with_length(obj, stream):
     write_int(len(obj), stream)
-    stream.write(obj)
-
-
-def read_with_length(stream):
-    length = read_int(stream)
-    obj = stream.read(length)
-    if obj == "":
-        raise EOFError
-    return obj
-
-
-def read_from_pickle_file(stream):
-    try:
-        while True:
-            obj = load_pickle(read_with_length(stream))
-            if type(obj) == Batch:  # We don't care about inheritance
-                for item in obj.items:
-                    yield item
-            else:
-                yield obj
-    except EOFError:
-        return
+    stream.write(obj)
\ No newline at end of file
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 29d6a128f6a9b83dc742c676ef010a90f54ab73e..621e1cb58c3df10afa1f64a8f7b9f988dd71b0cb 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -86,7 +86,8 @@ class TestCheckpoint(PySparkTestCase):
         time.sleep(1)  # 1 second
 
         self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
-        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
+                                            flatMappedRDD._jrdd_deserializer)
         self.assertEquals([1, 2, 3, 4], recovered.collect())
 
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d63c2aaef772de62eef3bf913ad4a4859cf30512..f2b3f3c1421d12d48637ca47b3ac39ed61bbcfac 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,23 +23,22 @@ import sys
 import time
 import socket
 import traceback
-from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, read_with_length, write_int, \
-    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+from pyspark.serializers import write_with_length, write_int, read_long, \
+    write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
 
 
-def load_obj(infile):
-    return load_pickle(standard_b64decode(infile.readline().strip()))
+pickleSer = PickleSerializer()
+mutf8_deserializer = MUTF8Deserializer()
 
 
 def report_times(outfile, boot, init, finish):
-    write_int(-3, outfile)
+    write_int(SpecialLengths.TIMING_DATA, outfile)
     write_long(1000 * boot, outfile)
     write_long(1000 * init, outfile)
     write_long(1000 * finish, outfile)
@@ -52,7 +51,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = load_pickle(read_with_length(infile))
+    spark_files_dir = mutf8_deserializer.loads(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
 
@@ -60,38 +59,33 @@ def main(infile, outfile):
     num_broadcast_variables = read_int(infile)
     for _ in range(num_broadcast_variables):
         bid = read_long(infile)
-        value = read_with_length(infile)
-        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+        value = pickleSer._read_with_length(infile)
+        _broadcastRegistry[bid] = Broadcast(bid, value)
 
     # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes =  read_int(infile)
     for _ in range(num_python_includes):
-        sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+        filename = mutf8_deserializer.loads(infile)
+        sys.path.append(os.path.join(spark_files_dir, filename))
 
-    # now load function
-    func = load_obj(infile)
-    bypassSerializer = load_obj(infile)
-    if bypassSerializer:
-        dumps = lambda x: x
-    else:
-        dumps = dump_pickle
+    command = pickleSer._read_with_length(infile)
+    (func, deserializer, serializer) = command
     init_time = time.time()
-    iterator = read_from_pickle_file(infile)
     try:
-        for obj in func(split_index, iterator):
-            write_with_length(dumps(obj), outfile)
+        iterator = deserializer.load_stream(infile)
+        serializer.dump_stream(func(split_index, iterator), outfile)
     except Exception as e:
-        write_int(-2, outfile)
+        write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
         write_with_length(traceback.format_exc(), outfile)
         sys.exit(-1)
     finish_time = time.time()
     report_times(outfile, boot_time, init_time, finish_time)
     # Mark the beginning of the accumulators section of the output
-    write_int(-1, outfile)
-    for aid, accum in _accumulatorRegistry.items():
-        write_with_length(dump_pickle((aid, accum._value)), outfile)
-    write_int(-1, outfile)
+    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+    write_int(len(_accumulatorRegistry), outfile)
+    for (aid, accum) in _accumulatorRegistry.items():
+        pickleSer._write_with_length((aid, accum._value), outfile)
 
 
 if __name__ == '__main__':
diff --git a/python/run-tests b/python/run-tests
index cbc554ea9db0d2cdd18323a408879d98f80810cb..d4dad672d284299a0a1fc116b24f7833eefe1b7b 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -37,6 +37,7 @@ run_test "pyspark/rdd.py"
 run_test "pyspark/context.py"
 run_test "-m doctest pyspark/broadcast.py"
 run_test "-m doctest pyspark/accumulators.py"
+run_test "-m doctest pyspark/serializers.py"
 run_test "pyspark/tests.py"
 
 if [[ $FAILED != 0 ]]; then
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 94e353af2e954f5e56970375b56e4dc60f1eeb5d..bb73f6d337ba04f47bd407c24b6311f0a50e10b6 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -54,9 +54,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
   // staging directory is private! -> rwx--------
   val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700:Short)
   // app files are world-wide readable and owner writable -> rw-r--r--
-  val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short) 
+  val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short)
 
-  def run() {
+  // For client users who want to monitor the app status themselves.
+  def runApp() = {
     validateArgs()
 
     init(yarnConf)
@@ -78,7 +79,11 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
     appContext.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
 
     submitApp(appContext)
-    
+    appId
+  }
+
+  def run() {
+    val appId = runApp()
     monitorApplication(appId)
     System.exit(0)
   }
@@ -372,7 +377,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
     val commands = List[String](javaCommand + 
       " -server " +
       JAVA_OPTS +
-      " org.apache.spark.deploy.yarn.ApplicationMaster" +
+      " " + args.amClass +
       " --class " + args.userClass + 
       " --jar " + args.userJar +
       userArgsToString(args) +
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index 852dbd7dabf66e391c4e64584670cdf0a74a853f..b9dbc3fb87a1f4ba4e9a37650bee52ae88e0a16c 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -35,6 +35,7 @@ class ClientArguments(val args: Array[String]) {
   var numWorkers = 2
   var amQueue = System.getProperty("QUEUE", "default")
   var amMemory: Int = 512
+  var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster"
   var appName: String = "Spark"
   // TODO
   var inputFormatInfo: List[InputFormatInfo] = null
@@ -62,18 +63,22 @@ class ClientArguments(val args: Array[String]) {
           userArgsBuffer += value
           args = tail
 
-        case ("--master-memory") :: MemoryParam(value) :: tail =>
-          amMemory = value
+        case ("--master-class") :: value :: tail =>
+          amClass = value
           args = tail
 
-        case ("--num-workers") :: IntParam(value) :: tail =>
-          numWorkers = value
+        case ("--master-memory") :: MemoryParam(value) :: tail =>
+          amMemory = value
           args = tail
 
         case ("--worker-memory") :: MemoryParam(value) :: tail =>
           workerMemory = value
           args = tail
 
+        case ("--num-workers") :: IntParam(value) :: tail =>
+          numWorkers = value
+          args = tail
+
         case ("--worker-cores") :: IntParam(value) :: tail =>
           workerCores = value
           args = tail
@@ -119,19 +124,20 @@ class ClientArguments(val args: Array[String]) {
     System.err.println(
       "Usage: org.apache.spark.deploy.yarn.Client [options] \n" +
       "Options:\n" +
-      "  --jar JAR_PATH       Path to your application's JAR file (required)\n" +
-      "  --class CLASS_NAME   Name of your application's main class (required)\n" +
-      "  --args ARGS          Arguments to be passed to your application's main class.\n" +
-      "                       Mutliple invocations are possible, each will be passed in order.\n" +
-      "  --num-workers NUM    Number of workers to start (Default: 2)\n" +
-      "  --worker-cores NUM   Number of cores for the workers (Default: 1). This is unsused right now.\n" +
-      "  --master-memory MEM  Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
-      "  --worker-memory MEM  Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
-      "  --name NAME          The name of your application (Default: Spark)\n" +
-      "  --queue QUEUE        The hadoop queue to use for allocation requests (Default: 'default')\n" +
-      "  --addJars jars       Comma separated list of local jars that want SparkContext.addJar to work with.\n" +
-      "  --files files        Comma separated list of files to be distributed with the job.\n" +
-      "  --archives archives  Comma separated list of archives to be distributed with the job."
+      "  --jar JAR_PATH             Path to your application's JAR file (required)\n" +
+      "  --class CLASS_NAME         Name of your application's main class (required)\n" +
+      "  --args ARGS                Arguments to be passed to your application's main class.\n" +
+      "                             Mutliple invocations are possible, each will be passed in order.\n" +
+      "  --num-workers NUM          Number of workers to start (Default: 2)\n" +
+      "  --worker-cores NUM         Number of cores for the workers (Default: 1). This is unsused right now.\n" +
+      "  --master-class CLASS_NAME  Class Name for Master (Default: spark.deploy.yarn.ApplicationMaster)\n" +
+      "  --master-memory MEM        Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
+      "  --worker-memory MEM        Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
+      "  --name NAME                The name of your application (Default: Spark)\n" +
+      "  --queue QUEUE              The hadoop queue to use for allocation requests (Default: 'default')\n" +
+      "  --addJars jars             Comma separated list of local jars that want SparkContext.addJar to work with.\n" +
+      "  --files files              Comma separated list of files to be distributed with the job.\n" +
+      "  --archives archives        Comma separated list of archives to be distributed with the job."
       )
     System.exit(exitCode)
   }
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
new file mode 100644
index 0000000000000000000000000000000000000000..421a83c87afdf156b19bee8a006f0dde11fa4c2a
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.Socket
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import akka.actor._
+import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}
+import akka.remote.RemoteClientShutdown
+import akka.actor.Terminated
+import akka.remote.RemoteClientDisconnected
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.util.{Utils, AkkaUtils}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.scheduler.SplitInfo
+
+class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
+
+  def this(args: ApplicationMasterArguments) = this(args, new Configuration())
+
+  private val rpc: YarnRPC = YarnRPC.create(conf)
+  private var resourceManager: AMRMProtocol = null
+  private var appAttemptId: ApplicationAttemptId = null
+  private var reporterThread: Thread = null
+  private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+  private var yarnAllocator: YarnAllocationHandler = null
+  private var driverClosed: Boolean = false
+
+  val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0)._1
+  var actor: ActorRef = null
+
+  // This actor just acts as a monitor that watches the driver actor.
+  class MonitorActor(driverUrl: String) extends Actor {
+
+    var driver: ActorRef = null
+
+    override def preStart() {
+      logInfo("Listen to driver: " + driverUrl)
+      driver = context.actorFor(driverUrl)
+      driver ! "hello"
+      context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+      context.watch(driver) // Doesn't work with remote actors, but useful for testing
+    }
+
+    override def receive = {
+      case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
+        logInfo("Driver terminated or disconnected! Shutting down.")
+        driverClosed = true
+    }
+  }
+
+  def run() {
+
+    appAttemptId = getApplicationAttemptId()
+    resourceManager = registerWithResourceManager()
+    val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
+
+    // Compute number of threads for akka
+    val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
+
+    if (minimumMemory > 0) {
+      val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
+      val numCore = (mem  / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
+
+      if (numCore > 0) {
+        // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
+        // TODO: Uncomment when hadoop is on a version which has this fixed.
+        // args.workerCores = numCore
+      }
+    }
+
+    waitForSparkMaster()
+
+    // Allocate all containers
+    allocateWorkers()
+
+    // Launch a progress reporter thread, else the app will get killed after the expiration timeout (default: 10 mins),
+    // and ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
+
+    val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+    // The interval must be <= timeoutInterval / 2.
+    // On the other hand, also ensure that we are reasonably responsive without causing too many requests to the RM,
+    // so use at least 1 minute or timeoutInterval / 10, whichever is higher.
+    val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
+    reporterThread = launchReporterThread(interval)
+
+    // Wait for the reporter thread to finish.
+    reporterThread.join()
+
+    finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
+    actorSystem.shutdown()
+
+    logInfo("Exited")
+    System.exit(0)
+  }
+
+  private def getApplicationAttemptId(): ApplicationAttemptId = {
+    val envs = System.getenv()
+    val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
+    val containerId = ConverterUtils.toContainerId(containerIdString)
+    val appAttemptId = containerId.getApplicationAttemptId()
+    logInfo("ApplicationAttemptId: " + appAttemptId)
+    return appAttemptId
+  }
+
+  private def registerWithResourceManager(): AMRMProtocol = {
+    val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
+      YarnConfiguration.RM_SCHEDULER_ADDRESS,
+      YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
+    logInfo("Connecting to ResourceManager at " + rmAddress)
+    return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
+  }
+
+  private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
+    logInfo("Registering the ApplicationMaster")
+    val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
+      .asInstanceOf[RegisterApplicationMasterRequest]
+    appMasterRequest.setApplicationAttemptId(appAttemptId)
+    // Setting this to the master host and port, so that the ApplicationReport at the client has some sensible info.
+    // Users can then monitor stderr/stdout on that node if required.
+    appMasterRequest.setHost(Utils.localHostName())
+    appMasterRequest.setRpcPort(0)
+    // What do we provide here ? Might make sense to expose something sensible later ?
+    appMasterRequest.setTrackingUrl("")
+    return resourceManager.registerApplicationMaster(appMasterRequest)
+  }
+
+  private def waitForSparkMaster() {
+    logInfo("Waiting for spark driver to be reachable.")
+    var driverUp = false
+    val hostport = args.userArgs(0)
+    val (driverHost, driverPort) = Utils.parseHostPort(hostport)
+    while(!driverUp) {
+      try {
+        val socket = new Socket(driverHost, driverPort)
+        socket.close()
+        logInfo("Master now available: " + driverHost + ":" + driverPort)
+        driverUp = true
+      } catch {
+        case e: Exception =>
+          logError("Failed to connect to driver at " + driverHost + ":" + driverPort)
+        Thread.sleep(100)
+      }
+    }
+    System.setProperty("spark.driver.host", driverHost)
+    System.setProperty("spark.driver.port", driverPort.toString)
+
+    val driverUrl = "akka://spark@%s:%s/user/%s".format(
+      driverHost, driverPort.toString, CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+    actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM")
+  }
+
+
+  private def allocateWorkers() {
+
+    // FIXME: should get preferredNodeLocationData from SparkContext; just fake an empty one for now.
+    val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map()
+
+    yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, preferredNodeLocationData)
+
+    logInfo("Allocating " + args.numWorkers + " workers.")
+    // Wait until all the requested workers have launched
+    // TODO: This is a bit ugly. Can we make it nicer?
+    // TODO: Handle container failure
+    while(yarnAllocator.getNumWorkersRunning < args.numWorkers) {
+      yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
+      Thread.sleep(100)
+    }
+
+    logInfo("All workers have launched.")
+
+  }
+
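+  // The reporter thread doubles as the AM heartbeat: every allocateContainers() call, even one
+  // requesting zero containers, counts as progress towards the RM expiry interval.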
+  // TODO: We might want to extend this to allocate more containers in case they die!
+  private def launchReporterThread(_sleepTime: Long): Thread = {
+    val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime
+
+    val t = new Thread {
+      override def run() {
+        while (!driverClosed) {
+          val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
+          if (missingWorkerCount > 0) {
+            logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers")
+            yarnAllocator.allocateContainers(missingWorkerCount)
+          }
+          else sendProgress()
+          Thread.sleep(sleepTime)
+        }
+      }
+    }
+    // Mark the reporter as a daemon thread so it does not block JVM exit, though this is usually not a good idea.
+    t.setDaemon(true)
+    t.start()
+    logInfo("Started progress reporter thread - sleep time: " + sleepTime)
+    return t
+  }
+
+  private def sendProgress() {
+    logDebug("Sending progress")
+    // simulated with an allocate request with no nodes requested ...
+    yarnAllocator.allocateContainers(0)
+  }
+
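+  // Unregister this application master from the ResourceManager, reporting the given final status.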
+  def finishApplicationMaster(status: FinalApplicationStatus) {
+
+    logInfo("finish ApplicationMaster with " + status)
+    val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
+      .asInstanceOf[FinishApplicationMasterRequest]
+    finishReq.setAppAttemptId(appAttemptId)
+    finishReq.setFinishApplicationStatus(status)
+    resourceManager.finishApplicationMaster(finishReq)
+  }
+
+}
+
+
+object WorkerLauncher {
+  def main(argStrings: Array[String]) {
+    val args = new ApplicationMasterArguments(argStrings)
+    new WorkerLauncher(args).run()
+  }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
new file mode 100644
index 0000000000000000000000000000000000000000..63a0449e5a0730085554d2b8ae86067135fa8dba
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark._
+import org.apache.hadoop.conf.Configuration
+import org.apache.spark.deploy.yarn.YarnAllocationHandler
+import org.apache.spark.util.Utils
+
+/**
+ * This scheduler launches workers through YARN by calling into Client to launch WorkerLauncher
+ * as the application master.
+ */
+private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+
+  def this(sc: SparkContext) = this(sc, new Configuration())
+
+  // By default, rack is unknown
+  override def getRackForHost(hostPort: String): Option[String] = {
+    val host = Utils.parseHostPort(hostPort)._1
+    val retval = YarnAllocationHandler.lookupRack(conf, host)
+    if (retval != null) Some(retval) else None
+  }
+
+  override def postStartHook() {
+
+    // The YARN application is running, but the workers might not be ready yet.
+    // Wait a few seconds for the slaves to bootstrap and register with the master - best-effort attempt.
+    Thread.sleep(2000L)
+    logInfo("YarnClientClusterScheduler.postStartHook done")
+  }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
new file mode 100644
index 0000000000000000000000000000000000000000..b206780c7806e15c84944db05876f89c8f848040
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState}
+import org.apache.spark.{SparkException, Logging, SparkContext}
+import org.apache.spark.deploy.yarn.{Client, ClientArguments}
+
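+/**
+ * Scheduler backend for yarn-client mode: the driver stays in the local client process while the
+ * workers run in YARN containers, launched by submitting WorkerLauncher as the application master.
+ */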
+private[spark] class YarnClientSchedulerBackend(
+    scheduler: ClusterScheduler,
+    sc: SparkContext)
+  extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+  with Logging {
+
+  var client: Client = null
+  var appId: ApplicationId = null
+
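+  // Translate the SPARK_YARN_APP_JAR / SPARK_WORKER_* environment variables into ClientArguments,
+  // submit WorkerLauncher as the application master, and block until YARN reports the
+  // application as RUNNING.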
+  override def start() {
+    super.start()
+
+    val defaultWorkerCores = "2"
+    val defaultWorkerMemory = "512m"
+    val defaultWorkerNumber = "1"
+
+    val userJar = System.getenv("SPARK_YARN_APP_JAR")
+    var workerCores = System.getenv("SPARK_WORKER_CORES")
+    var workerMemory = System.getenv("SPARK_WORKER_MEMORY")
+    var workerNumber = System.getenv("SPARK_WORKER_INSTANCES")
+
+    if (userJar == null)
+      throw new SparkException("env SPARK_YARN_APP_JAR is not set")
+
+    if (workerCores == null)
+      workerCores = defaultWorkerCores
+    if (workerMemory == null)
+      workerMemory = defaultWorkerMemory
+    if (workerNumber == null)
+      workerNumber = defaultWorkerNumber
+
+    val driverHost = System.getProperty("spark.driver.host")
+    val driverPort = System.getProperty("spark.driver.port")
+    val hostport = driverHost + ":" + driverPort
+
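+    // Pass the driver's host:port through --args so that WorkerLauncher, running as the
+    // application master, can connect back to this driver.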
+    val argsArray = Array[String](
+      "--class", "notused",
+      "--jar", userJar,
+      "--args", hostport,
+      "--worker-memory", workerMemory,
+      "--worker-cores", workerCores,
+      "--num-workers", workerNumber,
+      "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher"
+    )
+
+    val args = new ClientArguments(argsArray)
+    client = new Client(args)
+    appId = client.runApp()
+    waitForApp()
+  }
+
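+  // Poll the application report once a second and return when the application is RUNNING;
+  // fail fast if it has already finished, failed or been killed.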
+  def waitForApp() {
+
+    // TODO: Need a better way to find out whether the workers are ready or not,
+    // maybe by resource usage report?
+    while(true) {
+      val report = client.getApplicationReport(appId)
+
+      logInfo("Application report from ASM: \n" +
+        "\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
+        "\t appStartTime: " + report.getStartTime() + "\n" +
+        "\t yarnAppState: " + report.getYarnApplicationState() + "\n"
+      )
+
+      // Ready to go, or already gone.
+      val state = report.getYarnApplicationState()
+      if (state == YarnApplicationState.RUNNING) {
+        return
+      } else if (state == YarnApplicationState.FINISHED ||
+        state == YarnApplicationState.FAILED ||
+        state == YarnApplicationState.KILLED) {
+        throw new SparkException("Yarn application already ended," +
+          "might be killed or not able to launch application master.")
+      }
+
+      Thread.sleep(1000)
+    }
+  }
+
+  override def stop() {
+    super.stop()
+    client.stop()
+    logInfo("Stoped")
+  }
+
+}