diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 12b4d94a567ceec4713299fa49a54349b08c0cef..132e4fb0d2cadcd2018437f239c1519cee957bbd 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -27,13 +27,12 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.PipedRDD
 import org.apache.spark.util.Utils
 
 
 private[spark] class PythonRDD[T: ClassManifest](
     parent: RDD[T],
-    command: Seq[String],
+    command: Array[Byte],
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     preservePartitoning: Boolean,
@@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassManifest](
 
   val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
 
-  // Similar to Runtime.exec(), if we are given a single string, split it into words
-  // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, envVars: JMap[String, String],
-      pythonIncludes: JList[String],
-      preservePartitoning: Boolean, pythonExec: String,
-      broadcastVars: JList[Broadcast[Array[Byte]]],
-      accumulator: Accumulator[JList[Array[Byte]]]) =
-    this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
-      broadcastVars, accumulator)
-
   override def getPartitions = parent.partitions
 
   override val partitioner = if (preservePartitoning) parent.partitioner else None
 
-
   override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
     val startTime = System.currentTimeMillis
     val env = SparkEnv.get
@@ -71,11 +59,10 @@ private[spark] class PythonRDD[T: ClassManifest](
           SparkEnv.set(env)
           val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
           val dataOut = new DataOutputStream(stream)
-          val printOut = new PrintWriter(stream)
           // Partition index
           dataOut.writeInt(split.index)
           // sparkFilesDir
-          PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+          dataOut.writeUTF(SparkFiles.getRootDirectory)
           // Broadcast variables
           dataOut.writeInt(broadcastVars.length)
           for (broadcast <- broadcastVars) {
@@ -85,21 +72,16 @@ private[spark] class PythonRDD[T: ClassManifest](
           }
           // Python includes (*.zip and *.egg files)
           dataOut.writeInt(pythonIncludes.length)
-          for (f <- pythonIncludes) {
-            PythonRDD.writeAsPickle(f, dataOut)
-          }
+          pythonIncludes.foreach(dataOut.writeUTF)
           dataOut.flush()
-          // Serialized user code
-          for (elem <- command) {
-            printOut.println(elem)
-          }
-          printOut.flush()
+          // Serialized command:
+          dataOut.writeInt(command.length)
+          dataOut.write(command)
           // Data values
           for (elem <- parent.iterator(split, context)) {
-            PythonRDD.writeAsPickle(elem, dataOut)
+            PythonRDD.writeToStream(elem, dataOut)
           }
           dataOut.flush()
-          printOut.flush()
           worker.shutdownOutput()
         } catch {
           case e: IOException =>
@@ -132,7 +114,7 @@ private[spark] class PythonRDD[T: ClassManifest](
               val obj = new Array[Byte](length)
               stream.readFully(obj)
               obj
-            case -3 =>
+            case SpecialLengths.TIMING_DATA =>
               // Timing data from worker
               val bootTime = stream.readLong()
               val initTime = stream.readLong()
@@ -143,24 +125,24 @@ private[spark] class PythonRDD[T: ClassManifest](
               val total = finishTime - startTime
               logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
               read
-            case -2 =>
+            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
               // Signals that an exception has been thrown in python
               val exLength = stream.readInt()
               val obj = new Array[Byte](exLength)
               stream.readFully(obj)
               throw new PythonException(new String(obj))
-            case -1 =>
+            case SpecialLengths.END_OF_DATA_SECTION =>
               // We've finished the data section of the output, but we can still
-              // read some accumulator updates; let's do that, breaking when we
-              // get a negative length record.
-              var len2 = stream.readInt()
-              while (len2 >= 0) {
-                val update = new Array[Byte](len2)
+              // read some accumulator updates:
+              val numAccumulatorUpdates = stream.readInt()
+              (1 to numAccumulatorUpdates).foreach { _ =>
+                val updateLen = stream.readInt()
+                val update = new Array[Byte](updateLen)
                 stream.readFully(update)
                 accumulator += Collections.singletonList(update)
-                len2 = stream.readInt()
+
               }
-              new Array[Byte](0)
+              Array.empty[Byte]
           }
         } catch {
           case eof: EOFException => {
@@ -197,62 +179,15 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
   val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
 
-private[spark] object PythonRDD {
-
-  /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
-  def stripPickle(arr: Array[Byte]) : Array[Byte] = {
-    arr.slice(2, arr.length - 1)
-  }
+private object SpecialLengths {
+  val END_OF_DATA_SECTION = -1
+  val PYTHON_EXCEPTION_THROWN = -2
+  val TIMING_DATA = -3
+}
 
-  /**
-   * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
-   * The data format is a 32-bit integer representing the pickled object's length (in bytes),
-   * followed by the pickled data.
-   *
-   * Pickle module:
-   *
-   *    http://docs.python.org/2/library/pickle.html
-   *
-   * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
-   *
-   *    http://hg.python.org/cpython/file/2.6/Lib/pickle.py
-   *    http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
-   *
-   * @param elem the object to write
-   * @param dOut a data output stream
-   */
-  def writeAsPickle(elem: Any, dOut: DataOutputStream) {
-    if (elem.isInstanceOf[Array[Byte]]) {
-      val arr = elem.asInstanceOf[Array[Byte]]
-      dOut.writeInt(arr.length)
-      dOut.write(arr)
-    } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
-      val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
-      val length = t._1.length + t._2.length - 3 - 3 + 4  // stripPickle() removes 3 bytes
-      dOut.writeInt(length)
-      dOut.writeByte(Pickle.PROTO)
-      dOut.writeByte(Pickle.TWO)
-      dOut.write(PythonRDD.stripPickle(t._1))
-      dOut.write(PythonRDD.stripPickle(t._2))
-      dOut.writeByte(Pickle.TUPLE2)
-      dOut.writeByte(Pickle.STOP)
-    } else if (elem.isInstanceOf[String]) {
-      // For uniformity, strings are wrapped into Pickles.
-      val s = elem.asInstanceOf[String].getBytes("UTF-8")
-      val length = 2 + 1 + 4 + s.length + 1
-      dOut.writeInt(length)
-      dOut.writeByte(Pickle.PROTO)
-      dOut.writeByte(Pickle.TWO)
-      dOut.write(Pickle.BINUNICODE)
-      dOut.writeInt(Integer.reverseBytes(s.length))
-      dOut.write(s)
-      dOut.writeByte(Pickle.STOP)
-    } else {
-      throw new SparkException("Unexpected RDD type")
-    }
-  }
+private[spark] object PythonRDD {
 
-  def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+  def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
     val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
@@ -270,15 +205,32 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+  def writeToStream(elem: Any, dataOut: DataOutputStream) {
+    elem match {
+      case bytes: Array[Byte] =>
+        dataOut.writeInt(bytes.length)
+        dataOut.write(bytes)
+      case pair: (Array[Byte], Array[Byte]) =>
+        dataOut.writeInt(pair._1.length)
+        dataOut.write(pair._1)
+        dataOut.writeInt(pair._2.length)
+        dataOut.write(pair._2)
+      case str: String =>
+        dataOut.writeUTF(str)
+      case other =>
+        throw new SparkException("Unexpected element type " + other.getClass)
+    }
+  }
+
+  def writeToFile[T](items: java.util.Iterator[T], filename: String) {
     import scala.collection.JavaConverters._
-    writeIteratorToPickleFile(items.asScala, filename)
+    writeToFile(items.asScala, filename)
   }
 
-  def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
+  def writeToFile[T](items: Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
     for (item <- items) {
-      writeAsPickle(item, file)
+      writeToStream(item, file)
     }
     file.close()
   }
@@ -289,17 +241,6 @@ private[spark] object PythonRDD {
   }
 }
 
-private object Pickle {
-  val PROTO: Byte = 0x80.toByte
-  val TWO: Byte = 0x02.toByte
-  val BINUNICODE: Byte = 'X'
-  val STOP: Byte = '.'
-  val TUPLE2: Byte = 0x86.toByte
-  val EMPTY_LIST: Byte = ']'
-  val MARK: Byte = '('
-  val APPENDS: Byte = 'e'
-}
-
 private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
   override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
 }
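
The data-section framing used above — writeToStream emits a 32-bit big-endian
length followed by that many bytes, and negative lengths are reserved for the
SpecialLengths markers — can be illustrated with a small Python sketch. This is
not the real reader (that logic lives in PythonRDD.compute and in worker.py);
the function name and the minimal error handling are ours.

    import struct
    import io

    # Sentinels mirroring SpecialLengths above.
    END_OF_DATA_SECTION = -1
    PYTHON_EXCEPTION_THROWN = -2
    TIMING_DATA = -3

    def read_framed_records(stream):
        # Yield raw byte records from (length, payload) frames, stopping at
        # the END_OF_DATA_SECTION sentinel; other special lengths are left
        # unhandled here for brevity.
        while True:
            header = stream.read(4)
            if len(header) < 4:
                return
            length = struct.unpack(">i", header)[0]
            if length == END_OF_DATA_SECTION:
                return
            if length < 0:
                raise ValueError("special length %d not handled in this sketch" % length)
            yield stream.read(length)

    buf = io.BytesIO(struct.pack(">i", 5) + b"hello" +
                     struct.pack(">i", END_OF_DATA_SECTION))
    assert list(read_framed_records(buf)) == [b"hello"]
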
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 1d0d002d36623409f341688811d4d846ec6ee9fe..0b42e729f8dcc756c711584de2b2a4f071b480c5 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -32,6 +32,6 @@ target: docs/
 
 private: no
 
-exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
+exclude: pyspark.cloudpickle pyspark.worker pyspark.join
          pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
          pyspark.rddsampler pyspark.daemon
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index da3d96689aa15dc14707e8f88c9170e51af3cede..2204e9c9ca7011f10ca60c900896db13868c35ff 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -90,9 +90,11 @@ import struct
 import SocketServer
 import threading
 from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import read_int, read_with_length, load_pickle
+from pyspark.serializers import read_int, PickleSerializer
 
 
+pickleSer = PickleSerializer()
+
 # Holds accumulators registered on the current machine, keyed by ID. This is then used to send
 # the local accumulator updates back to the driver program at the end of a task.
 _accumulatorRegistry = {}
@@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
         from pyspark.accumulators import _accumulatorRegistry
         num_updates = read_int(self.rfile)
         for _ in range(num_updates):
-            (aid, update) = load_pickle(read_with_length(self.rfile))
+            (aid, update) = pickleSer._read_with_length(self.rfile)
             _accumulatorRegistry[aid] += update
         # Write a byte in acknowledgement
         self.wfile.write(struct.pack("!b", 1))
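
The handler above expects a 32-bit update count followed by one framed pickle
of (accumulator id, update) per update, and acknowledges with a single byte.
A rough sketch of the writing side is shown below; in Spark the JVM produces
this stream, the helper name is ours, and we assume read_int/write_int use
network byte order ("!i") as elsewhere in pyspark.serializers.

    import struct
    import cPickle

    def send_accumulator_updates(updates, stream):
        # updates: list of (accumulator_id, update_value) pairs
        stream.write(struct.pack("!i", len(updates)))
        for aid, update in updates:
            data = cPickle.dumps((aid, update), 2)
            stream.write(struct.pack("!i", len(data)))   # frame length
            stream.write(data)                           # pickled (aid, update)
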
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a7ca8bc888c6759aff5784d26ad7df015d2fe2f4..cbd41e58c4a780392b2a6b8c58320535e416cd36 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 
@@ -42,7 +42,7 @@ class SparkContext(object):
 
     _gateway = None
     _jvm = None
-    _writeIteratorToPickleFile = None
+    _writeToFile = None
     _takePartition = None
     _next_accum_id = 0
     _active_spark_context = None
@@ -51,7 +51,7 @@ class SparkContext(object):
 
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
-        environment=None, batchSize=1024):
+        environment=None, batchSize=1024, serializer=PickleSerializer()):
         """
         Create a new SparkContext.
 
@@ -67,6 +67,7 @@ class SparkContext(object):
         @param batchSize: The number of Python objects represented as a single
                Java object.  Set 1 to disable batching or -1 to use an
                unlimited batch size.
+        @param serializer: The serializer for RDDs.
 
 
         >>> from pyspark.context import SparkContext
@@ -83,7 +84,13 @@ class SparkContext(object):
         self.jobName = jobName
         self.sparkHome = sparkHome or None # None becomes null in Py4J
         self.environment = environment or {}
-        self.batchSize = batchSize  # -1 represents a unlimited batch size
+        self._batchSize = batchSize  # -1 represents an unlimited batch size
+        self._unbatched_serializer = serializer
+        if batchSize == 1:
+            self.serializer = self._unbatched_serializer
+        else:
+            self.serializer = BatchedSerializer(self._unbatched_serializer,
+                                                batchSize)
 
         # Create the Java SparkContext through Py4J
         empty_string_array = self._gateway.new_array(self._jvm.String, 0)
@@ -125,8 +132,8 @@ class SparkContext(object):
             if not SparkContext._gateway:
                 SparkContext._gateway = launch_gateway()
                 SparkContext._jvm = SparkContext._gateway.jvm
-                SparkContext._writeIteratorToPickleFile = \
-                    SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+                SparkContext._writeToFile = \
+                    SparkContext._jvm.PythonRDD.writeToFile
                 SparkContext._takePartition = \
                     SparkContext._jvm.PythonRDD.takePartition
 
@@ -184,15 +191,17 @@ class SparkContext(object):
         # Make sure we distribute data evenly if it's smaller than self.batchSize
         if "__len__" not in dir(c):
             c = list(c)    # Make it a list so we can compute its length
-        batchSize = min(len(c) // numSlices, self.batchSize)
+        batchSize = min(len(c) // numSlices, self._batchSize)
         if batchSize > 1:
-            c = batched(c, batchSize)
-        for x in c:
-            write_with_length(dump_pickle(x), tempFile)
+            serializer = BatchedSerializer(self._unbatched_serializer,
+                                           batchSize)
+        else:
+            serializer = self._unbatched_serializer
+        serializer.dump_stream(c, tempFile)
         tempFile.close()
-        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
-        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
-        return RDD(jrdd, self)
+        readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+        jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+        return RDD(jrdd, self, serializer)
 
     def textFile(self, name, minSplits=None):
         """
@@ -201,21 +210,39 @@ class SparkContext(object):
         RDD of Strings.
         """
         minSplits = minSplits or min(self.defaultParallelism, 2)
-        jrdd = self._jsc.textFile(name, minSplits)
-        return RDD(jrdd, self)
+        return RDD(self._jsc.textFile(name, minSplits), self,
+                   MUTF8Deserializer())
 
-    def _checkpointFile(self, name):
+    def _checkpointFile(self, name, input_deserializer):
         jrdd = self._jsc.checkpointFile(name)
-        return RDD(jrdd, self)
+        return RDD(jrdd, self, input_deserializer)
 
     def union(self, rdds):
         """
         Build the union of a list of RDDs.
+
+        This supports union() of RDDs with different serialized formats,
+        although this forces them to be reserialized using the default
+        serializer:
+
+        >>> path = os.path.join(tempdir, "union-text.txt")
+        >>> with open(path, "w") as testFile:
+        ...    testFile.write("Hello")
+        >>> textFile = sc.textFile(path)
+        >>> textFile.collect()
+        [u'Hello']
+        >>> parallelized = sc.parallelize(["World!"])
+        >>> sorted(sc.union([textFile, parallelized]).collect())
+        [u'Hello', 'World!']
         """
+        first_jrdd_deserializer = rdds[0]._jrdd_deserializer
+        if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
+            rdds = [x._reserialize() for x in rdds]
         first = rdds[0]._jrdd
         rest = [x._jrdd for x in rdds[1:]]
-        rest = ListConverter().convert(rest, self.gateway._gateway_client)
-        return RDD(self._jsc.union(first, rest), self)
+        rest = ListConverter().convert(rest, self._gateway._gateway_client)
+        return RDD(self._jsc.union(first, rest), self,
+                   rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):
         """
@@ -223,7 +250,9 @@ class SparkContext(object):
         object for reading it in distributed functions. The variable will be
         sent to each cluster only once.
         """
-        jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+        pickleSer = PickleSerializer()
+        pickled = pickleSer.dumps(value)
+        jbroadcast = self._jsc.broadcast(bytearray(pickled))
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
 
@@ -235,7 +264,7 @@ class SparkContext(object):
         and floating-point numbers if you do not provide one. For other types,
         a custom AccumulatorParam can be used.
         """
-        if accum_param == None:
+        if accum_param is None:
             if isinstance(value, int):
                 accum_param = accumulators.INT_ACCUMULATOR_PARAM
             elif isinstance(value, float):
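
The batch size chosen by parallelize() above is the smaller of the evenly
divided collection size and the context's configured batch size; only when the
result is greater than 1 does the unbatched serializer get wrapped in a
BatchedSerializer. A tiny worked sketch (the helper name is ours):

    def effective_batch_size(num_elements, num_slices, context_batch_size):
        # Mirrors the arithmetic in parallelize() above.
        return min(num_elements // num_slices, context_batch_size)

    # 16 elements over 4 slices with the default batchSize of 1024:
    assert effective_batch_size(16, 4, 1024) == 4   # one batch of 4 per partition
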
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7019fb8beefc8f5c8c2060899e2c28e8ae10643e..957f3f89c0eb2dda5afbda8ca3c9a850727308f7 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,7 +18,7 @@
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
-from itertools import chain, ifilter, imap, product
+from itertools import chain, ifilter, imap
 import operator
 import os
 import sys
@@ -27,9 +27,8 @@ from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
 
-from pyspark import cloudpickle
-from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
-    read_from_pickle_file, pack_long
+from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
+    BatchedSerializer, CloudPickleSerializer, pack_long
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -48,12 +47,12 @@ class RDD(object):
     operated on in parallel.
     """
 
-    def __init__(self, jrdd, ctx):
+    def __init__(self, jrdd, ctx, jrdd_deserializer):
         self._jrdd = jrdd
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = ctx
-        self._partitionFunc = None
+        self._jrdd_deserializer = jrdd_deserializer
 
     @property
     def context(self):
@@ -247,7 +246,23 @@ class RDD(object):
         >>> rdd.union(rdd).collect()
         [1, 1, 2, 3, 1, 1, 2, 3]
         """
-        return RDD(self._jrdd.union(other._jrdd), self.ctx)
+        if self._jrdd_deserializer == other._jrdd_deserializer:
+            rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
+                      self._jrdd_deserializer)
+            return rdd
+        else:
+            # These RDDs contain data in different serialized formats, so we
+            # must normalize them to the default serializer.
+            self_copy = self._reserialize()
+            other_copy = other._reserialize()
+            return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+                       self.ctx.serializer)
+
+    def _reserialize(self):
+        if self._jrdd_deserializer == self.ctx.serializer:
+            return self
+        else:
+            return self.map(lambda x: x, preservesPartitioning=True)
 
     def __add__(self, other):
         """
@@ -334,17 +349,9 @@ class RDD(object):
         [(1, 1), (1, 2), (2, 1), (2, 2)]
         """
         # Due to batching, we can't use the Java cartesian method.
-        java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
-        def unpack_batches(pair):
-            (x, y) = pair
-            if type(x) == Batch or type(y) == Batch:
-                xs = x.items if type(x) == Batch else [x]
-                ys = y.items if type(y) == Batch else [y]
-                for pair in product(xs, ys):
-                    yield pair
-            else:
-                yield pair
-        return java_cartesian.flatMap(unpack_batches)
+        deserializer = CartesianDeserializer(self._jrdd_deserializer,
+                                             other._jrdd_deserializer)
+        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
 
     def groupBy(self, f, numPartitions=None):
         """
@@ -391,8 +398,8 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        picklesInJava = self._jrdd.collect().iterator()
-        return list(self._collect_iterator_through_file(picklesInJava))
+        bytesInJava = self._jrdd.collect().iterator()
+        return list(self._collect_iterator_through_file(bytesInJava))
 
     def _collect_iterator_through_file(self, iterator):
         # Transferring lots of data through Py4J can be slow because
@@ -400,10 +407,10 @@ class RDD(object):
         # file and read it back.
         tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
         tempFile.close()
-        self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+        self.ctx._writeToFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
-            for item in read_from_pickle_file(tempFile):
+            for item in self._jrdd_deserializer.load_stream(tempFile):
                 yield item
         os.unlink(tempFile.name)
 
@@ -571,7 +578,7 @@ class RDD(object):
         items = []
         for partition in range(mapped._jrdd.splits().size()):
             iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
-            items.extend(self._collect_iterator_through_file(iterator))
+            items.extend(mapped._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
         return items[:num]
@@ -735,6 +742,7 @@ class RDD(object):
         # Transferring O(n) objects to Java is too expensive.  Instead, we'll
         # form the hash buckets in Python, transferring O(numPartitions) objects
         # to Java.  Each object is a (splitNumber, [objects]) pair.
+        outputSerializer = self.ctx._unbatched_serializer
         def add_shuffle_key(split, iterator):
 
             buckets = defaultdict(list)
@@ -743,14 +751,14 @@ class RDD(object):
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
             for (split, items) in buckets.iteritems():
                 yield pack_long(split)
-                yield dump_pickle(Batch(items))
+                yield outputSerializer.dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
         partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                      id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
-        rdd = RDD(jrdd, self.ctx)
+        rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if
         # partitionFunc is a lambda:
         rdd._partitionFunc = partitionFunc
@@ -787,7 +795,8 @@ class RDD(object):
             numPartitions = self.ctx.defaultParallelism
         def combineLocally(iterator):
             combiners = {}
-            for (k, v) in iterator:
+            for x in iterator:
+                (k, v) = x
                 if k not in combiners:
                     combiners[k] = createCombiner(v)
                 else:
@@ -929,38 +938,39 @@ class PipelinedRDD(RDD):
     20
     """
     def __init__(self, prev, func, preservesPartitioning=False):
-        if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
+        if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
+            # This transformation is the first in its stage:
+            self.func = func
+            self.preservesPartitioning = preservesPartitioning
+            self._prev_jrdd = prev._jrdd
+            self._prev_jrdd_deserializer = prev._jrdd_deserializer
+        else:
             prev_func = prev.func
             def pipeline_func(split, iterator):
                 return func(split, prev_func(split, iterator))
             self.func = pipeline_func
             self.preservesPartitioning = \
                 prev.preservesPartitioning and preservesPartitioning
-            self._prev_jrdd = prev._prev_jrdd
-        else:
-            self.func = func
-            self.preservesPartitioning = preservesPartitioning
-            self._prev_jrdd = prev._jrdd
+            self._prev_jrdd = prev._prev_jrdd  # maintain the pipeline
+            self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = prev.ctx
         self.prev = prev
         self._jrdd_val = None
+        self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
 
     @property
     def _jrdd(self):
         if self._jrdd_val:
             return self._jrdd_val
-        func = self.func
-        if not self._bypass_serializer and self.ctx.batchSize != 1:
-            oldfunc = self.func
-            batchSize = self.ctx.batchSize
-            def batched_func(split, iterator):
-                return batched(oldfunc(split, iterator), batchSize)
-            func = batched_func
-        cmds = [func, self._bypass_serializer]
-        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+        if self._bypass_serializer:
+            serializer = NoOpSerializer()
+        else:
+            serializer = self.ctx.serializer
+        command = (self.func, self._prev_jrdd_deserializer, serializer)
+        pickled_command = CloudPickleSerializer().dumps(command)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
             self.ctx._gateway._gateway_client)
@@ -971,8 +981,9 @@ class PipelinedRDD(RDD):
         includes = ListConverter().convert(self.ctx._python_includes,
                                      self.ctx._gateway._gateway_client)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-            pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
-            broadcast_vars, self.ctx._javaAccumulator, class_manifest)
+            bytearray(pickled_command), env, includes, self.preservesPartitioning,
+            self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator,
+            class_manifest)
         self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
 
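
PipelinedRDD above folds successive per-partition functions into a single
function, so a chain of map/filter calls crosses the Python worker boundary
only once per stage. A standalone sketch of that composition (the names here
are illustrative, not part of the patch):

    def pipeline(prev_func, func):
        # Same shape as pipeline_func in PipelinedRDD.__init__ above.
        def pipeline_func(split, iterator):
            return func(split, prev_func(split, iterator))
        return pipeline_func

    double = lambda split, it: (x * 2 for x in it)         # stands in for map()
    keep_big = lambda split, it: (x for x in it if x > 4)  # stands in for filter()

    composed = pipeline(double, keep_big)
    assert list(composed(0, iter([1, 2, 3, 4]))) == [6, 8]
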
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 54fed1c9c70f66e503abb5c523d6327bb9bae8b4..811fa6f018b23f3c9883bd2a770f03c044786850 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -15,45 +15,269 @@
 # limitations under the License.
 #
 
-import struct
+"""
+PySpark supports custom serializers for transferring data; this can improve
+performance.
+
+By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
+C{cPickle} serializer, which can serialize nearly any Python object.
+Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
+faster.
+
+The serializer is chosen when creating L{SparkContext}:
+
+>>> from pyspark.context import SparkContext
+>>> from pyspark.serializers import MarshalSerializer
+>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
+>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+>>> sc.stop()
+
+By default, PySpark serializes objects in batches; the batch size can be
+controlled through SparkContext's C{batchSize} parameter
+(the default size is 1024 objects):
+
+>>> sc = SparkContext('local', 'test', batchSize=2)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+
+Behind the scenes, this creates a JavaRDD with four partitions, each of
+which contains two batches of two objects:
+
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+8L
+>>> sc.stop()
+
+A batch size of -1 uses an unlimited batch size, and a size of 1 disables
+batching:
+
+>>> sc = SparkContext('local', 'test', batchSize=1)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+16L
+"""
+
 import cPickle
+from itertools import chain, izip, product
+import marshal
+import struct
+from pyspark import cloudpickle
+
+
+__all__ = ["PickleSerializer", "MarshalSerializer"]
+
+
+class SpecialLengths(object):
+    END_OF_DATA_SECTION = -1
+    PYTHON_EXCEPTION_THROWN = -2
+    TIMING_DATA = -3
+
+
+class Serializer(object):
+
+    def dump_stream(self, iterator, stream):
+        """
+        Serialize an iterator of objects to the output stream.
+        """
+        raise NotImplementedError
+
+    def load_stream(self, stream):
+        """
+        Return an iterator of deserialized objects from the input stream.
+        """
+        raise NotImplementedError
+
+
+    def _load_stream_without_unbatching(self, stream):
+        return self.load_stream(stream)
+
+    # Note: our notion of "equality" is that output generated by
+    # equal serializers can be deserialized using the same serializer.
+
+    # This default implementation handles the simple cases;
+    # subclasses should override __eq__ as appropriate.
+
+    def __eq__(self, other):
+        return isinstance(other, self.__class__)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class FramedSerializer(Serializer):
+    """
+    Serializer that writes objects as a stream of (length, data) pairs,
+    where C{length} is a 32-bit integer and C{data} is C{length} bytes.
+    """
+
+    def dump_stream(self, iterator, stream):
+        for obj in iterator:
+            self._write_with_length(obj, stream)
+
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self._read_with_length(stream)
+            except EOFError:
+                return
+
+    def _write_with_length(self, obj, stream):
+        serialized = self.dumps(obj)
+        write_int(len(serialized), stream)
+        stream.write(serialized)
+
+    def _read_with_length(self, stream):
+        length = read_int(stream)
+        obj = stream.read(length)
+        if obj == "":
+            raise EOFError
+        return self.loads(obj)
+
+    def dumps(self, obj):
+        """
+        Serialize an object into a byte array.
+        When batching is used, this will be called with an array of objects.
+        """
+        raise NotImplementedError
+
+    def loads(self, obj):
+        """
+        Deserialize an object from a byte array.
+        """
+        raise NotImplementedError
+
+
+class BatchedSerializer(Serializer):
+    """
+    Serializes a stream of objects in batches by calling its wrapped
+    Serializer with streams of objects.
+    """
+
+    UNLIMITED_BATCH_SIZE = -1
+
+    def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
+        self.serializer = serializer
+        self.batchSize = batchSize
+
+    def _batched(self, iterator):
+        if self.batchSize == self.UNLIMITED_BATCH_SIZE:
+            yield list(iterator)
+        else:
+            items = []
+            count = 0
+            for item in iterator:
+                items.append(item)
+                count += 1
+                if count == self.batchSize:
+                    yield items
+                    items = []
+                    count = 0
+            if items:
+                yield items
+
+    def dump_stream(self, iterator, stream):
+        self.serializer.dump_stream(self._batched(iterator), stream)
+
+    def load_stream(self, stream):
+        return chain.from_iterable(self._load_stream_without_unbatching(stream))
+
+    def _load_stream_without_unbatching(self, stream):
+        return self.serializer.load_stream(stream)
+
+    def __eq__(self, other):
+        return isinstance(other, BatchedSerializer) and \
+               other.serializer == self.serializer
+
+    def __str__(self):
+        return "BatchedSerializer<%s>" % str(self.serializer)
 
 
-class Batch(object):
+class CartesianDeserializer(FramedSerializer):
     """
-    Used to store multiple RDD entries as a single Java object.
+    Deserializes the JavaRDD cartesian() of two PythonRDDs.
+    """
+
+    def __init__(self, key_ser, val_ser):
+        self.key_ser = key_ser
+        self.val_ser = val_ser
+
+    def load_stream(self, stream):
+        key_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_stream = self.val_ser._load_stream_without_unbatching(stream)
+        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
+        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
+        for (keys, vals) in izip(key_stream, val_stream):
+            keys = keys if key_is_batched else [keys]
+            vals = vals if val_is_batched else [vals]
+            for pair in product(keys, vals):
+                yield pair
+
+    def __eq__(self, other):
+        return isinstance(other, CartesianDeserializer) and \
+               self.key_ser == other.key_ser and self.val_ser == other.val_ser
+
+    def __str__(self):
+        return "CartesianDeserializer<%s, %s>" % \
+               (str(self.key_ser), str(self.val_ser))
 
-    This relieves us from having to explicitly track whether an RDD
-    is stored as batches of objects and avoids problems when processing
-    the union() of batched and unbatched RDDs (e.g. the union() of textFile()
-    with another RDD).
+
+class NoOpSerializer(FramedSerializer):
+
+    def loads(self, obj): return obj
+    def dumps(self, obj): return obj
+
+
+class PickleSerializer(FramedSerializer):
     """
-    def __init__(self, items):
-        self.items = items
+    Serializes objects using Python's cPickle serializer:
 
+        http://docs.python.org/2/library/pickle.html
 
-def batched(iterator, batchSize):
-    if batchSize == -1: # unlimited batch size
-        yield Batch(list(iterator))
-    else:
-        items = []
-        count = 0
-        for item in iterator:
-            items.append(item)
-            count += 1
-            if count == batchSize:
-                yield Batch(items)
-                items = []
-                count = 0
-        if items:
-            yield Batch(items)
+    This serializer supports nearly any Python object, but may
+    not be as fast as more specialized serializers.
+    """
 
+    def dumps(self, obj): return cPickle.dumps(obj, 2)
+    loads = cPickle.loads
 
-def dump_pickle(obj):
-    return cPickle.dumps(obj, 2)
+class CloudPickleSerializer(PickleSerializer):
 
+    def dumps(self, obj): return cloudpickle.dumps(obj, 2)
 
-load_pickle = cPickle.loads
+
+class MarshalSerializer(FramedSerializer):
+    """
+    Serializes objects using Python's Marshal serializer:
+
+        http://docs.python.org/2/library/marshal.html
+
+    This serializer is faster than PickleSerializer but supports fewer datatypes.
+    """
+
+    dumps = marshal.dumps
+    loads = marshal.loads
+
+
+class MUTF8Deserializer(Serializer):
+    """
+    Deserializes streams written by Java's DataOutputStream.writeUTF().
+    """
+
+    def loads(self, stream):
+        length = struct.unpack('>H', stream.read(2))[0]
+        return stream.read(length).decode('utf8')
+
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self.loads(stream)
+            except struct.error:
+                return
+            except EOFError:
+                return
 
 
 def read_long(stream):
@@ -84,25 +308,4 @@ def write_int(value, stream):
 
 def write_with_length(obj, stream):
     write_int(len(obj), stream)
-    stream.write(obj)
-
-
-def read_with_length(stream):
-    length = read_int(stream)
-    obj = stream.read(length)
-    if obj == "":
-        raise EOFError
-    return obj
-
-
-def read_from_pickle_file(stream):
-    try:
-        while True:
-            obj = load_pickle(read_with_length(stream))
-            if type(obj) == Batch:  # We don't care about inheritance
-                for item in obj.items:
-                    yield item
-            else:
-                yield obj
-    except EOFError:
-        return
+    stream.write(obj)
\ No newline at end of file
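
MUTF8Deserializer above reads the two-byte length prefix produced by
java.io.DataOutputStream.writeUTF(). As a rough counterpart, the writing side
can be sketched with the stdlib; this ignores the corner cases where Java's
modified UTF-8 differs from standard UTF-8 (embedded NULs and supplementary
characters), and the helper name is ours.

    import struct
    import io

    def write_mutf8(s, stream):
        # Two-byte big-endian length followed by the encoded bytes, as
        # writeUTF() does for ordinary strings.
        data = s.encode("utf-8")
        stream.write(struct.pack(">H", len(data)))
        stream.write(data)

    buf = io.BytesIO()
    write_mutf8(u"Hello", buf)
    buf.seek(0)
    # MUTF8Deserializer().loads(buf) would now return u'Hello'.
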
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 29d6a128f6a9b83dc742c676ef010a90f54ab73e..621e1cb58c3df10afa1f64a8f7b9f988dd71b0cb 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -86,7 +86,8 @@ class TestCheckpoint(PySparkTestCase):
         time.sleep(1)  # 1 second
 
         self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
-        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
+                                            flatMappedRDD._jrdd_deserializer)
         self.assertEquals([1, 2, 3, 4], recovered.collect())
 
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d63c2aaef772de62eef3bf913ad4a4859cf30512..f2b3f3c1421d12d48637ca47b3ac39ed61bbcfac 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,23 +23,22 @@ import sys
 import time
 import socket
 import traceback
-from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, read_with_length, write_int, \
-    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+from pyspark.serializers import write_with_length, write_int, read_long, \
+    write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
 
 
-def load_obj(infile):
-    return load_pickle(standard_b64decode(infile.readline().strip()))
+pickleSer = PickleSerializer()
+mutf8_deserializer = MUTF8Deserializer()
 
 
 def report_times(outfile, boot, init, finish):
-    write_int(-3, outfile)
+    write_int(SpecialLengths.TIMING_DATA, outfile)
     write_long(1000 * boot, outfile)
     write_long(1000 * init, outfile)
     write_long(1000 * finish, outfile)
@@ -52,7 +51,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = load_pickle(read_with_length(infile))
+    spark_files_dir = mutf8_deserializer.loads(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
 
@@ -60,38 +59,33 @@ def main(infile, outfile):
     num_broadcast_variables = read_int(infile)
     for _ in range(num_broadcast_variables):
         bid = read_long(infile)
-        value = read_with_length(infile)
-        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+        value = pickleSer._read_with_length(infile)
+        _broadcastRegistry[bid] = Broadcast(bid, value)
 
     # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes =  read_int(infile)
     for _ in range(num_python_includes):
-        sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+        filename = mutf8_deserializer.loads(infile)
+        sys.path.append(os.path.join(spark_files_dir, filename))
 
-    # now load function
-    func = load_obj(infile)
-    bypassSerializer = load_obj(infile)
-    if bypassSerializer:
-        dumps = lambda x: x
-    else:
-        dumps = dump_pickle
+    command = pickleSer._read_with_length(infile)
+    (func, deserializer, serializer) = command
     init_time = time.time()
-    iterator = read_from_pickle_file(infile)
     try:
-        for obj in func(split_index, iterator):
-            write_with_length(dumps(obj), outfile)
+        iterator = deserializer.load_stream(infile)
+        serializer.dump_stream(func(split_index, iterator), outfile)
     except Exception as e:
-        write_int(-2, outfile)
+        write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
         write_with_length(traceback.format_exc(), outfile)
         sys.exit(-1)
     finish_time = time.time()
     report_times(outfile, boot_time, init_time, finish_time)
     # Mark the beginning of the accumulators section of the output
-    write_int(-1, outfile)
-    for aid, accum in _accumulatorRegistry.items():
-        write_with_length(dump_pickle((aid, accum._value)), outfile)
-    write_int(-1, outfile)
+    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+    write_int(len(_accumulatorRegistry), outfile)
+    for (aid, accum) in _accumulatorRegistry.items():
+        pickleSer._write_with_length((aid, accum._value), outfile)
 
 
 if __name__ == '__main__':
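
The worker loop above reduces to: unpickle a (func, input deserializer, output
serializer) tuple and stream one partition through it. A minimal sketch of that
core step, assuming the command bytes form an ordinary pickle (in Spark they
are written by CloudPickleSerializer, whose output is read back with the
regular pickle machinery); the function name is ours and the bootstrap and
error-reporting paths are omitted.

    import cPickle

    def run_command(command_bytes, infile, outfile, split_index):
        # Mirrors the heart of main() above.
        func, deserializer, serializer = cPickle.loads(command_bytes)
        iterator = deserializer.load_stream(infile)
        serializer.dump_stream(func(split_index, iterator), outfile)
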
diff --git a/python/run-tests b/python/run-tests
index cbc554ea9db0d2cdd18323a408879d98f80810cb..d4dad672d284299a0a1fc116b24f7833eefe1b7b 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -37,6 +37,7 @@ run_test "pyspark/rdd.py"
 run_test "pyspark/context.py"
 run_test "-m doctest pyspark/broadcast.py"
 run_test "-m doctest pyspark/accumulators.py"
+run_test "-m doctest pyspark/serializers.py"
 run_test "pyspark/tests.py"
 
 if [[ $FAILED != 0 ]]; then