diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index eb0b0db0cc2efb9331b12c0e6e09778d115a7883..ef9bf4db9b1b1d60373a9fa6ca842b2bee055ef2 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -221,18 +221,6 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeStringAsPickle(elem: String, dOut: DataOutputStream) {
-    val s = elem.getBytes("UTF-8")
-    val length = 2 + 1 + 4 + s.length + 1
-    dOut.writeInt(length)
-    dOut.writeByte(Pickle.PROTO)
-    dOut.writeByte(Pickle.TWO)
-    dOut.write(Pickle.BINUNICODE)
-    dOut.writeInt(Integer.reverseBytes(s.length))
-    dOut.write(s)
-    dOut.writeByte(Pickle.STOP)
-  }
-
   def writeToStream(elem: Any, dataOut: DataOutputStream) {
     elem match {
       case bytes: Array[Byte] =>
@@ -244,9 +232,7 @@ private[spark] object PythonRDD {
         dataOut.writeInt(pair._2.length)
         dataOut.write(pair._2)
       case str: String =>
-        // Until we've implemented full custom serializer support, we need to return
-        // strings as Pickles to properly support union() and cartesian():
-        writeStringAsPickle(str, dataOut)
+        dataOut.writeUTF(str)
       case other =>
         throw new SparkException("Unexpected element type " + other.getClass)
     }
@@ -271,13 +257,6 @@ private[spark] object PythonRDD {
   }
 }
 
-private object Pickle {
-  val PROTO: Byte = 0x80.toByte
-  val TWO: Byte = 0x02.toByte
-  val BINUNICODE: Byte = 'X'
-  val STOP: Byte = '.'
-}
-
 private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
   override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
 }
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 1d0d002d36623409f341688811d4d846ec6ee9fe..0b42e729f8dcc756c711584de2b2a4f071b480c5 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -32,6 +32,6 @@ target: docs/
 
 private: no
 
-exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
+exclude: pyspark.cloudpickle pyspark.worker pyspark.join
          pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
          pyspark.rddsampler pyspark.daemon
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index da3d96689aa15dc14707e8f88c9170e51af3cede..2204e9c9ca7011f10ca60c900896db13868c35ff 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -90,9 +90,11 @@ import struct
 import SocketServer
 import threading
 from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import read_int, read_with_length, load_pickle
+from pyspark.serializers import read_int, PickleSerializer
 
 
+pickleSer = PickleSerializer()
+
 # Holds accumulators registered on the current machine, keyed by ID. This is then used to send
 # the local accumulator updates back to the driver program at the end of a task.
 _accumulatorRegistry = {}
@@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
         from pyspark.accumulators import _accumulatorRegistry
         num_updates = read_int(self.rfile)
         for _ in range(num_updates):
-            (aid, update) = load_pickle(read_with_length(self.rfile))
+            (aid, update) = pickleSer._read_with_length(self.rfile)
             _accumulatorRegistry[aid] += update
         # Write a byte in acknowledgement
         self.wfile.write(struct.pack("!b", 1))
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 0fec1a6bf67d4d2afeeb06cd605c64f3055f722b..6bb1c6c3a1e4bc77e924bd6f7365ed981b4a73a4 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 
@@ -51,7 +51,7 @@ class SparkContext(object):
 
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
-        environment=None, batchSize=1024):
+        environment=None, batchSize=1024, serializer=PickleSerializer()):
         """
         Create a new SparkContext.
 
@@ -67,6 +67,7 @@ class SparkContext(object):
         @param batchSize: The number of Python objects represented as a single
                Java object.  Set 1 to disable batching or -1 to use an
                unlimited batch size.
+        @param serializer: The serializer for RDDs.
 
 
         >>> from pyspark.context import SparkContext
@@ -83,7 +84,13 @@ class SparkContext(object):
         self.jobName = jobName
         self.sparkHome = sparkHome or None # None becomes null in Py4J
         self.environment = environment or {}
-        self.batchSize = batchSize  # -1 represents a unlimited batch size
+        self._batchSize = batchSize  # -1 represents an unlimited batch size
+        self._unbatched_serializer = serializer
+        if batchSize == 1:
+            self.serializer = self._unbatched_serializer
+        else:
+            self.serializer = BatchedSerializer(self._unbatched_serializer,
+                                                batchSize)
 
         # Create the Java SparkContext through Py4J
         empty_string_array = self._gateway.new_array(self._jvm.String, 0)
@@ -184,15 +191,17 @@ class SparkContext(object):
         # Make sure we distribute data evenly if it's smaller than self.batchSize
         if "__len__" not in dir(c):
             c = list(c)    # Make it a list so we can compute its length
-        batchSize = min(len(c) // numSlices, self.batchSize)
+        batchSize = min(len(c) // numSlices, self._batchSize)
         if batchSize > 1:
-            c = batched(c, batchSize)
-        for x in c:
-            write_with_length(dump_pickle(x), tempFile)
+            serializer = BatchedSerializer(self._unbatched_serializer,
+                                           batchSize)
+        else:
+            serializer = self._unbatched_serializer
+        serializer.dump_stream(c, tempFile)
         tempFile.close()
         readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
         jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
-        return RDD(jrdd, self)
+        return RDD(jrdd, self, serializer)
 
     def textFile(self, name, minSplits=None):
         """
@@ -201,21 +210,39 @@ class SparkContext(object):
         RDD of Strings.
         """
         minSplits = minSplits or min(self.defaultParallelism, 2)
-        jrdd = self._jsc.textFile(name, minSplits)
-        return RDD(jrdd, self)
+        return RDD(self._jsc.textFile(name, minSplits), self,
+                   MUTF8Deserializer())
 
-    def _checkpointFile(self, name):
+    def _checkpointFile(self, name, input_deserializer):
         jrdd = self._jsc.checkpointFile(name)
-        return RDD(jrdd, self)
+        return RDD(jrdd, self, input_deserializer)
 
     def union(self, rdds):
         """
         Build the union of a list of RDDs.
+
+        This supports unions() of RDDs with different serialized formats,
+        although this forces them to be reserialized using the default
+        serializer:
+
+        >>> path = os.path.join(tempdir, "union-text.txt")
+        >>> with open(path, "w") as testFile:
+        ...    testFile.write("Hello")
+        >>> textFile = sc.textFile(path)
+        >>> textFile.collect()
+        [u'Hello']
+        >>> parallelized = sc.parallelize(["World!"])
+        >>> sorted(sc.union([textFile, parallelized]).collect())
+        [u'Hello', 'World!']
         """
+        first_jrdd_deserializer = rdds[0]._jrdd_deserializer
+        if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
+            rdds = [x._reserialize() for x in rdds]
         first = rdds[0]._jrdd
         rest = [x._jrdd for x in rdds[1:]]
-        rest = ListConverter().convert(rest, self.gateway._gateway_client)
-        return RDD(self._jsc.union(first, rest), self)
+        rest = ListConverter().convert(rest, self._gateway._gateway_client)
+        return RDD(self._jsc.union(first, rest), self,
+                   rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):
         """
@@ -223,7 +250,9 @@ class SparkContext(object):
         object for reading it in distributed functions. The variable will be
         sent to each cluster only once.
         """
-        jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+        pickleSer = PickleSerializer()
+        pickled = pickleSer._dumps(value)
+        jbroadcast = self._jsc.broadcast(bytearray(pickled))
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
 
@@ -235,7 +264,7 @@ class SparkContext(object):
         and floating-point numbers if you do not provide one. For other types,
         a custom AccumulatorParam can be used.
         """
-        if accum_param == None:
+        if accum_param is None:
             if isinstance(value, int):
                 accum_param = accumulators.INT_ACCUMULATOR_PARAM
             elif isinstance(value, float):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d3c4d13a1e0a28d5a66abd97442f67fba1686374..6691c30519db62e14ac661ae49d02336f5e92ada 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,7 +18,7 @@
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
-from itertools import chain, ifilter, imap, product
+from itertools import chain, ifilter, imap
 import operator
 import os
 import sys
@@ -28,8 +28,8 @@ from tempfile import NamedTemporaryFile
 from threading import Thread
 
 from pyspark import cloudpickle
-from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
-    read_from_pickle_file, pack_long
+from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
+    BatchedSerializer, pack_long
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -48,13 +48,12 @@ class RDD(object):
     operated on in parallel.
     """
 
-    def __init__(self, jrdd, ctx):
+    def __init__(self, jrdd, ctx, jrdd_deserializer):
         self._jrdd = jrdd
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = ctx
-        self._partitionFunc = None
-        self._stage_input_is_pairs = False
+        self._jrdd_deserializer = jrdd_deserializer
 
     @property
     def context(self):
@@ -248,7 +247,23 @@ class RDD(object):
         >>> rdd.union(rdd).collect()
         [1, 1, 2, 3, 1, 1, 2, 3]
         """
-        return RDD(self._jrdd.union(other._jrdd), self.ctx)
+        if self._jrdd_deserializer == other._jrdd_deserializer:
+            rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
+                      self._jrdd_deserializer)
+            return rdd
+        else:
+            # These RDDs contain data in different serialized formats, so we
+            # must normalize them to the default serializer.
+            self_copy = self._reserialize()
+            other_copy = other._reserialize()
+            return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+                       self.ctx.serializer)
+
+    def _reserialize(self):
+        if self._jrdd_deserializer == self.ctx.serializer:
+            return self
+        else:
+            return self.map(lambda x: x, preservesPartitioning=True)
 
     def __add__(self, other):
         """
@@ -335,18 +350,9 @@ class RDD(object):
         [(1, 1), (1, 2), (2, 1), (2, 2)]
         """
         # Due to batching, we can't use the Java cartesian method.
-        java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
-        def unpack_batches(pair):
-            (x, y) = pair
-            if type(x) == Batch or type(y) == Batch:
-                xs = x.items if type(x) == Batch else [x]
-                ys = y.items if type(y) == Batch else [y]
-                for pair in product(xs, ys):
-                    yield pair
-            else:
-                yield pair
-        java_cartesian._stage_input_is_pairs = True
-        return java_cartesian.flatMap(unpack_batches)
+        deserializer = CartesianDeserializer(self._jrdd_deserializer,
+                                             other._jrdd_deserializer)
+        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
 
     def groupBy(self, f, numPartitions=None):
         """
@@ -405,7 +411,7 @@ class RDD(object):
         self.ctx._writeToFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
-            for item in read_from_pickle_file(tempFile):
+            for item in self._jrdd_deserializer.load_stream(tempFile):
                 yield item
         os.unlink(tempFile.name)
 
@@ -573,7 +579,7 @@ class RDD(object):
         items = []
         for partition in range(mapped._jrdd.splits().size()):
             iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
-            items.extend(self._collect_iterator_through_file(iterator))
+            items.extend(mapped._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
         return items[:num]
@@ -737,6 +743,7 @@ class RDD(object):
         # Transferring O(n) objects to Java is too expensive.  Instead, we'll
         # form the hash buckets in Python, transferring O(numPartitions) objects
         # to Java.  Each object is a (splitNumber, [objects]) pair.
+        outputSerializer = self.ctx._unbatched_serializer
         def add_shuffle_key(split, iterator):
 
             buckets = defaultdict(list)
@@ -745,14 +752,14 @@ class RDD(object):
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
             for (split, items) in buckets.iteritems():
                 yield pack_long(split)
-                yield dump_pickle(Batch(items))
+                yield outputSerializer._dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
         pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
         partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
                                                      id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
-        rdd = RDD(jrdd, self.ctx)
+        rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if
         # partitionFunc is a lambda:
         rdd._partitionFunc = partitionFunc
@@ -789,7 +796,8 @@ class RDD(object):
             numPartitions = self.ctx.defaultParallelism
         def combineLocally(iterator):
             combiners = {}
-            for (k, v) in iterator:
+            for x in iterator:
+                (k, v) = x
                 if k not in combiners:
                     combiners[k] = createCombiner(v)
                 else:
@@ -931,38 +939,38 @@ class PipelinedRDD(RDD):
     20
     """
     def __init__(self, prev, func, preservesPartitioning=False):
-        if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
+        if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
+            # This transformation is the first in its stage:
+            self.func = func
+            self.preservesPartitioning = preservesPartitioning
+            self._prev_jrdd = prev._jrdd
+            self._prev_jrdd_deserializer = prev._jrdd_deserializer
+        else:
             prev_func = prev.func
             def pipeline_func(split, iterator):
                 return func(split, prev_func(split, iterator))
             self.func = pipeline_func
             self.preservesPartitioning = \
                 prev.preservesPartitioning and preservesPartitioning
-            self._prev_jrdd = prev._prev_jrdd
-        else:
-            self.func = func
-            self.preservesPartitioning = preservesPartitioning
-            self._prev_jrdd = prev._jrdd
-        self._stage_input_is_pairs = prev._stage_input_is_pairs
+            self._prev_jrdd = prev._prev_jrdd  # maintain the pipeline
+            self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = prev.ctx
         self.prev = prev
         self._jrdd_val = None
+        self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
 
     @property
     def _jrdd(self):
         if self._jrdd_val:
             return self._jrdd_val
-        func = self.func
-        if not self._bypass_serializer and self.ctx.batchSize != 1:
-            oldfunc = self.func
-            batchSize = self.ctx.batchSize
-            def batched_func(split, iterator):
-                return batched(oldfunc(split, iterator), batchSize)
-            func = batched_func
-        cmds = [func, self._bypass_serializer, self._stage_input_is_pairs]
+        if self._bypass_serializer:
+            serializer = NoOpSerializer()
+        else:
+            serializer = self.ctx.serializer
+        cmds = [self.func, self._prev_jrdd_deserializer, serializer]
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index fd02e1ee8fbd7a8a83e2249f08da965f97703362..4fb444443f7849c7baa908b9b1b1adcbc3194318 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -15,8 +15,58 @@
 # limitations under the License.
 #
 
-import struct
+"""
+PySpark supports custom serializers for transferring data; this can improve
+performance.
+
+By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
+C{cPickle} serializer, which can serialize nearly any Python object.
+Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
+faster.
+
+The serializer is chosen when creating L{SparkContext}:
+
+>>> from pyspark.context import SparkContext
+>>> from pyspark.serializers import MarshalSerializer
+>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
+>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+>>> sc.stop()
+
+By default, PySpark serialize objects in batches; the batch size can be
+controlled through SparkContext's C{batchSize} parameter
+(the default size is 1024 objects):
+
+>>> sc = SparkContext('local', 'test', batchSize=2)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+
+Behind the scenes, this creates a JavaRDD with four partitions, each of
+which contains two batches of two objects:
+
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+8L
+>>> sc.stop()
+
+A batch size of -1 uses an unlimited batch size, and a size of 1 disables
+batching:
+
+>>> sc = SparkContext('local', 'test', batchSize=1)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+16L
+"""
+
 import cPickle
+from itertools import chain, izip, product
+import marshal
+import struct
+
+
+__all__ = ["PickleSerializer", "MarshalSerializer"]
 
 
 class SpecialLengths(object):
@@ -25,41 +75,206 @@ class SpecialLengths(object):
     TIMING_DATA = -3
 
 
-class Batch(object):
+class Serializer(object):
+
+    def dump_stream(self, iterator, stream):
+        """
+        Serialize an iterator of objects to the output stream.
+        """
+        raise NotImplementedError
+
+    def load_stream(self, stream):
+        """
+        Return an iterator of deserialized objects from the input stream.
+        """
+        raise NotImplementedError
+
+
+    def _load_stream_without_unbatching(self, stream):
+        return self.load_stream(stream)
+
+    # Note: our notion of "equality" is that output generated by
+    # equal serializers can be deserialized using the same serializer.
+
+    # This default implementation handles the simple cases;
+    # subclasses should override __eq__ as appropriate.
+
+    def __eq__(self, other):
+        return isinstance(other, self.__class__)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+
+class FramedSerializer(Serializer):
+    """
+    Serializer that writes objects as a stream of (length, data) pairs,
+    where C{length} is a 32-bit integer and data is C{length} bytes.
+    """
+
+    def dump_stream(self, iterator, stream):
+        for obj in iterator:
+            self._write_with_length(obj, stream)
+
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self._read_with_length(stream)
+            except EOFError:
+                return
+
+    def _write_with_length(self, obj, stream):
+        serialized = self._dumps(obj)
+        write_int(len(serialized), stream)
+        stream.write(serialized)
+
+    def _read_with_length(self, stream):
+        length = read_int(stream)
+        obj = stream.read(length)
+        if obj == "":
+            raise EOFError
+        return self._loads(obj)
+
+    def _dumps(self, obj):
+        """
+        Serialize an object into a byte array.
+        When batching is used, this will be called with an array of objects.
+        """
+        raise NotImplementedError
+
+    def _loads(self, obj):
+        """
+        Deserialize an object from a byte array.
+        """
+        raise NotImplementedError
+
+
+class BatchedSerializer(Serializer):
+    """
+    Serializes a stream of objects in batches by calling its wrapped
+    Serializer with streams of objects.
+    """
+
+    UNLIMITED_BATCH_SIZE = -1
+
+    def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
+        self.serializer = serializer
+        self.batchSize = batchSize
+
+    def _batched(self, iterator):
+        if self.batchSize == self.UNLIMITED_BATCH_SIZE:
+            yield list(iterator)
+        else:
+            items = []
+            count = 0
+            for item in iterator:
+                items.append(item)
+                count += 1
+                if count == self.batchSize:
+                    yield items
+                    items = []
+                    count = 0
+            if items:
+                yield items
+
+    def dump_stream(self, iterator, stream):
+        if isinstance(iterator, basestring):
+            iterator = [iterator]
+        self.serializer.dump_stream(self._batched(iterator), stream)
+
+    def load_stream(self, stream):
+        return chain.from_iterable(self._load_stream_without_unbatching(stream))
+
+    def _load_stream_without_unbatching(self, stream):
+            return self.serializer.load_stream(stream)
+
+    def __eq__(self, other):
+        return isinstance(other, BatchedSerializer) and \
+               other.serializer == self.serializer
+
+    def __str__(self):
+        return "BatchedSerializer<%s>" % str(self.serializer)
+
+
+class CartesianDeserializer(FramedSerializer):
     """
-    Used to store multiple RDD entries as a single Java object.
+    Deserializes the JavaRDD cartesian() of two PythonRDDs.
+    """
+
+    def __init__(self, key_ser, val_ser):
+        self.key_ser = key_ser
+        self.val_ser = val_ser
+
+    def load_stream(self, stream):
+        key_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_stream = self.val_ser._load_stream_without_unbatching(stream)
+        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
+        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
+        for (keys, vals) in izip(key_stream, val_stream):
+            keys = keys if key_is_batched else [keys]
+            vals = vals if val_is_batched else [vals]
+            for pair in product(keys, vals):
+                yield pair
+
+    def __eq__(self, other):
+        return isinstance(other, CartesianDeserializer) and \
+               self.key_ser == other.key_ser and self.val_ser == other.val_ser
+
+    def __str__(self):
+        return "CartesianDeserializer<%s, %s>" % \
+               (str(self.key_ser), str(self.val_ser))
+
+
+class NoOpSerializer(FramedSerializer):
+
+    def _loads(self, obj): return obj
+    def _dumps(self, obj): return obj
+
+
+class PickleSerializer(FramedSerializer):
+    """
+    Serializes objects using Python's cPickle serializer:
+
+        http://docs.python.org/2/library/pickle.html
+
+    This serializer supports nearly any Python object, but may
+    not be as fast as more specialized serializers.
+    """
+
+    def _dumps(self, obj): return cPickle.dumps(obj, 2)
+    _loads = cPickle.loads
+
 
-    This relieves us from having to explicitly track whether an RDD
-    is stored as batches of objects and avoids problems when processing
-    the union() of batched and unbatched RDDs (e.g. the union() of textFile()
-    with another RDD).
+class MarshalSerializer(FramedSerializer):
     """
-    def __init__(self, items):
-        self.items = items
+    Serializes objects using Python's Marshal serializer:
 
+        http://docs.python.org/2/library/marshal.html
 
-def batched(iterator, batchSize):
-    if batchSize == -1: # unlimited batch size
-        yield Batch(list(iterator))
-    else:
-        items = []
-        count = 0
-        for item in iterator:
-            items.append(item)
-            count += 1
-            if count == batchSize:
-                yield Batch(items)
-                items = []
-                count = 0
-        if items:
-            yield Batch(items)
+    This serializer is faster than PickleSerializer but supports fewer datatypes.
+    """
+
+    _dumps = marshal.dumps
+    _loads = marshal.loads
 
 
-def dump_pickle(obj):
-    return cPickle.dumps(obj, 2)
+class MUTF8Deserializer(Serializer):
+    """
+    Deserializes streams written by Java's DataOutputStream.writeUTF().
+    """
 
+    def _loads(self, stream):
+        length = struct.unpack('>H', stream.read(2))[0]
+        return stream.read(length).decode('utf8')
 
-load_pickle = cPickle.loads
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self._loads(stream)
+            except struct.error:
+                return
+            except EOFError:
+                return
 
 
 def read_long(stream):
@@ -90,43 +305,4 @@ def write_int(value, stream):
 
 def write_with_length(obj, stream):
     write_int(len(obj), stream)
-    stream.write(obj)
-
-
-def read_mutf8(stream):
-    """
-    Read a string written with Java's DataOutputStream.writeUTF() method.
-    """
-    length = struct.unpack('>H', stream.read(2))[0]
-    return stream.read(length).decode('utf8')
-
-
-def read_with_length(stream):
-    length = read_int(stream)
-    obj = stream.read(length)
-    if obj == "":
-        raise EOFError
-    return obj
-
-
-def read_from_pickle_file(stream):
-    try:
-        while True:
-            obj = load_pickle(read_with_length(stream))
-            if type(obj) == Batch:  # We don't care about inheritance
-                for item in obj.items:
-                    yield item
-            else:
-                yield obj
-    except EOFError:
-        return
-
-
-def read_pairs_from_pickle_file(stream):
-    try:
-        while True:
-            a = load_pickle(read_with_length(stream))
-            b = load_pickle(read_with_length(stream))
-            yield (a, b)
-    except EOFError:
-        return
\ No newline at end of file
+    stream.write(obj)
\ No newline at end of file
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 29d6a128f6a9b83dc742c676ef010a90f54ab73e..621e1cb58c3df10afa1f64a8f7b9f988dd71b0cb 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -86,7 +86,8 @@ class TestCheckpoint(PySparkTestCase):
         time.sleep(1)  # 1 second
 
         self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
-        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+        recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
+                                            flatMappedRDD._jrdd_deserializer)
         self.assertEquals([1, 2, 3, 4], recovered.collect())
 
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 4e64557fc49e414657f14cdfe645807f7e5bffbd..5b16d5db7e859bd4b008063ca886565b6a73a2c6 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,13 +30,17 @@ from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, read_with_length, write_int, \
-    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \
-    SpecialLengths, read_mutf8, read_pairs_from_pickle_file
+from pyspark.serializers import write_with_length, write_int, read_long, \
+    write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
+
+
+pickleSer = PickleSerializer()
+mutf8_deserializer = MUTF8Deserializer()
 
 
 def load_obj(infile):
-    return load_pickle(standard_b64decode(infile.readline().strip()))
+    decoded = standard_b64decode(infile.readline().strip())
+    return pickleSer._loads(decoded)
 
 
 def report_times(outfile, boot, init, finish):
@@ -53,7 +57,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = read_mutf8(infile)
+    spark_files_dir = mutf8_deserializer._loads(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
 
@@ -61,31 +65,24 @@ def main(infile, outfile):
     num_broadcast_variables = read_int(infile)
     for _ in range(num_broadcast_variables):
         bid = read_long(infile)
-        value = read_with_length(infile)
-        _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+        value = pickleSer._read_with_length(infile)
+        _broadcastRegistry[bid] = Broadcast(bid, value)
 
     # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes =  read_int(infile)
     for _ in range(num_python_includes):
-        sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile)))
+        filename = mutf8_deserializer._loads(infile)
+        sys.path.append(os.path.join(spark_files_dir, filename))
 
-    # now load function
+    # Load this stage's function and serializer:
     func = load_obj(infile)
-    bypassSerializer = load_obj(infile)
-    stageInputIsPairs = load_obj(infile)
-    if bypassSerializer:
-        dumps = lambda x: x
-    else:
-        dumps = dump_pickle
+    deserializer = load_obj(infile)
+    serializer = load_obj(infile)
     init_time = time.time()
-    if stageInputIsPairs:
-        iterator = read_pairs_from_pickle_file(infile)
-    else:
-        iterator = read_from_pickle_file(infile)
     try:
-        for obj in func(split_index, iterator):
-            write_with_length(dumps(obj), outfile)
+        iterator = deserializer.load_stream(infile)
+        serializer.dump_stream(func(split_index, iterator), outfile)
     except Exception as e:
         write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
         write_with_length(traceback.format_exc(), outfile)
@@ -96,7 +93,7 @@ def main(infile, outfile):
     write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
     write_int(len(_accumulatorRegistry), outfile)
     for (aid, accum) in _accumulatorRegistry.items():
-        write_with_length(dump_pickle((aid, accum._value)), outfile)
+        pickleSer._write_with_length((aid, accum._value), outfile)
 
 
 if __name__ == '__main__':
diff --git a/python/run-tests b/python/run-tests
index cbc554ea9db0d2cdd18323a408879d98f80810cb..d4dad672d284299a0a1fc116b24f7833eefe1b7b 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -37,6 +37,7 @@ run_test "pyspark/rdd.py"
 run_test "pyspark/context.py"
 run_test "-m doctest pyspark/broadcast.py"
 run_test "-m doctest pyspark/accumulators.py"
+run_test "-m doctest pyspark/serializers.py"
 run_test "pyspark/tests.py"
 
 if [[ $FAILED != 0 ]]; then