diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 7409dc2d866f69a88be0b2871fd5005d5756646b..2d92f6a42b308a4d8787d60602de18eb3df3701f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -47,6 +47,7 @@ private[spark] class PythonRDD(
     pythonIncludes: JList[String],
     preservePartitoning: Boolean,
     pythonExec: String,
+    pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent) {
@@ -210,6 +211,8 @@ private[spark] class PythonRDD(
         val dataOut = new DataOutputStream(stream)
         // Partition index
         dataOut.writeInt(split.index)
+        // Python version of driver
+        PythonRDD.writeUTF(pythonVer, dataOut)
         // sparkFilesDir
         PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
         // Python includes (*.zip and *.egg files)
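
On the JVM side, `PythonRDD` now carries the driver's Python version and writes it into the worker's task stream right after the partition index, using the same length-prefixed UTF-8 framing as the other string fields. A minimal sketch of that framing, assuming `writeUTF` emits a 4-byte big-endian length followed by the UTF-8 bytes (the helper below is illustrative, not the actual implementation):

```python
import struct

def write_utf(s, stream):
    # Sketch of the framing used for string fields in the task stream:
    # a 4-byte big-endian length prefix followed by the UTF-8 payload.
    data = s.encode("utf-8")
    stream.write(struct.pack("!i", len(data)))
    stream.write(data)

# Stream layout seen by the worker at task start (sketch):
#   partition index (int) -> driver Python version, e.g. "2.7"
#   -> sparkFilesDir -> Python includes -> ...
```

Placing the version before `sparkFilesDir` lets the worker reject a mismatched interpreter before any other task setup happens.
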
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 31992795a9e4562162da3d5423f6483ae6683788..d25ee855235bef41743183b5a23c1905b8042f9c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -173,6 +173,7 @@ class SparkContext(object):
             self._jvm.PythonAccumulatorParam(host, port))
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
+        self.pythonVer = "%d.%d" % sys.version_info[:2]
 
         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have
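
On the driver, `SparkContext` computes the `major.minor` string once at startup from `sys.version_info`, so every `PythonRDD` and Python UDF constructed later passes the same value to the JVM. For example:

```python
import sys

# "%d.%d" % sys.version_info[:2] keeps only major.minor, e.g.:
#   (2, 7, 9) -> "2.7"
#   (3, 4, 3) -> "3.4"
print("%d.%d" % sys.version_info[:2])
```
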
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 545c5ad20cb96cbed06103522c40dea177e52b44..70db4bbe4cbc5f45f414072cbf689259a768be14 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2260,7 +2260,7 @@ class RDD(object):
 def _prepare_for_python_RDD(sc, command, obj=None):
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
-    pickled_command = ser.dumps((command, sys.version_info[:2]))
+    pickled_command = ser.dumps(command)
     if len(pickled_command) > (1 << 20):  # 1M
         # The broadcast will have same life cycle as created PythonRDD
         broadcast = sc.broadcast(pickled_command)
@@ -2344,7 +2344,7 @@ class PipelinedRDD(RDD):
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
                                              bytearray(pickled_cmd),
                                              env, includes, self.preservesPartitioning,
-                                             self.ctx.pythonExec,
+                                             self.ctx.pythonExec, self.ctx.pythonVer,
                                              bvars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()
 
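
With the version no longer folded into the pickled command, the payload that gets broadcast (when it exceeds 1 MB) contains only the function and its serializers, and the worker can unpack it directly. A sketch of the now-symmetric driver/worker pair, using hypothetical variable names:

```python
# Driver side (sketch): only the command tuple is pickled (and possibly broadcast).
pickled_command = ser.dumps((func, profiler, deserializer, serializer))

# Worker side (sketch): the unpickled object unpacks directly; there is no
# trailing version element to strip off any more.
func, profiler, deserializer, serializer = pickleSer.loads(pickled_command)
```
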
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index f6f107ca32d2f52b9e6b196f303de142129d6575..0bde7191242ab30a3d13878f7d6cadbcd2c8a11d 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -157,6 +157,7 @@ class SQLContext(object):
                                             env,
                                             includes,
                                             self._sc.pythonExec,
+                                            self._sc.pythonVer,
                                             bvars,
                                             self._sc._javaAccumulator,
                                             returnType.json())
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8d0e766ecd3b4bbb3fb7789286e820ca8f2bbb36..fbe9bf5b526af233ae7b3a8bf8a3300dc65ac3bd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -353,8 +353,8 @@ class UserDefinedFunction(object):
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
         fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
-        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env,
-                                                 includes, sc.pythonExec, broadcast_vars,
+        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
+                                                 sc.pythonExec, sc.pythonVer, broadcast_vars,
                                                  sc._javaAccumulator, jdt)
         return judf
 
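
`UserDefinedPythonFunction` now also receives `sc.pythonVer`, so a UDF registered from PySpark is subject to the same worker-side check as a plain `PythonRDD`. A hypothetical usage sketch (`df` is assumed to be an existing DataFrame with an integer `age` column):

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

plus_one = udf(lambda x: x + 1, IntegerType())

# If the executors resolve a different Python minor version than the driver,
# the mismatch surfaces as a Py4JJavaError when the UDF's task runs on a
# worker, not when the pickled command is loaded.
df.select(plus_one(df.age)).collect()
```
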
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 09de4d159fdcf4a65981b1a697ef3b9fdd1d9821..5e023f6c53517fc15e11f9ee20800710b44d03d8 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1543,13 +1543,13 @@ class WorkerTests(ReusedPySparkTestCase):
     def test_with_different_versions_of_python(self):
         rdd = self.sc.parallelize(range(10))
         rdd.count()
-        version = sys.version_info
-        sys.version_info = (2, 0, 0)
+        version = self.sc.pythonVer
+        self.sc.pythonVer = "2.0"
         try:
             with QuietTest(self.sc):
                 self.assertRaises(Py4JJavaError, lambda: rdd.count())
         finally:
-            sys.version_info = version
+            self.sc.pythonVer = version
 
 
 class SparkSubmitTests(unittest.TestCase):
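
The test no longer monkey-patches `sys.version_info` (which affects the whole interpreter); it overrides the cached `pythonVer` string on the live `SparkContext` instead. A standalone sketch of the same scenario, assuming an active context `sc`:

```python
saved = sc.pythonVer
sc.pythonVer = "2.0"  # pretend the driver runs an incompatible Python
try:
    sc.parallelize(range(10)).count()  # the worker raises the version-mismatch error
finally:
    sc.pythonVer = saved
```
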
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index fbdaf3a5814cd921ba6b97cd883d492999d65362..93df9002be377ba66cba632908a8dc6e9dbb18c7 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -57,6 +57,12 @@ def main(infile, outfile):
         if split_index == -1:  # for unit tests
             exit(-1)
 
+        version = utf8_deserializer.loads(infile)
+        if version != "%d.%d" % sys.version_info[:2]:
+            raise Exception(("Python in worker has different version %s than that in " +
+                             "driver %s, PySpark cannot run with different minor versions") %
+                            ("%d.%d" % sys.version_info[:2], version))
+
         # initialize global state
         shuffle.MemoryBytesSpilled = 0
         shuffle.DiskBytesSpilled = 0
@@ -92,11 +98,7 @@ def main(infile, outfile):
         command = pickleSer._read_with_length(infile)
         if isinstance(command, Broadcast):
             command = pickleSer.loads(command.value)
-        (func, profiler, deserializer, serializer), version = command
-        if version != sys.version_info[:2]:
-            raise Exception(("Python in worker has different version %s than that in " +
-                            "driver %s, PySpark cannot run with different minor versions") %
-                            (sys.version_info[:2], version))
+        func, profiler, deserializer, serializer = command
         init_time = time.time()
 
         def process():
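
The worker now reads the driver's version string immediately after the split index and fails fast, before shuffle state is initialized or the (possibly broadcast) command is unpickled; the check compares only `major.minor`, so patch-level differences are tolerated. A minimal standalone sketch of the guard:

```python
import sys

def check_driver_python_version(driver_ver):
    # Compare major.minor only: 2.7.9 vs 2.7.10 passes, 2.7 vs 3.4 fails.
    worker_ver = "%d.%d" % sys.version_info[:2]
    if worker_ver != driver_ver:
        raise Exception("Python in worker has different version %s than that in "
                        "driver %s, PySpark cannot run with different minor versions"
                        % (worker_ver, driver_ver))
```
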
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index dc3389c41bbfa81dd03cfecf0e0c3acfb7558dea..3cc5c2441d8a5c182d386cb60bdb72e57b8f77e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -46,6 +46,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
       envVars: JMap[String, String],
       pythonIncludes: JList[String],
       pythonExec: String,
+      pythonVer: String,
       broadcastVars: JList[Broadcast[PythonBroadcast]],
       accumulator: Accumulator[JList[Array[Byte]]],
       stringDataType: String): Unit = {
@@ -70,6 +71,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
         envVars,
         pythonIncludes,
         pythonExec,
+        pythonVer,
         broadcastVars,
         accumulator,
         dataType,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
index 505ab1301ec9673837afa264727d6ff0d38c58c9..a02e202d2eebcc96c8215013f87615ee410ddf87 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -58,14 +58,15 @@ private[sql] case class UserDefinedPythonFunction(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
+    pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType) {
 
   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
   def apply(exprs: Column*): Column = {
-    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
-      accumulator, dataType, exprs.map(_.expr))
+    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer,
+      broadcastVars, accumulator, dataType, exprs.map(_.expr))
     Column(udf)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 65dd7ba020fa3502b3be9f89ad6a5d7accdbd059..11b2897f76786bafc79c15a91f066066ca6ac57d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -46,6 +46,7 @@ private[spark] case class PythonUDF(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
+    pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType,
@@ -251,6 +252,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
       udf.pythonIncludes,
       false,
       udf.pythonExec,
+      udf.pythonVer,
       udf.broadcastVars,
       udf.accumulator
     ).mapPartitions { iter =>