Commit 1381fc72 authored by Josh Rosen

Switch from MUTF8 to UTF8 in PySpark serializers.

This fixes SPARK-1043, a bug introduced in 0.9.0
where PySpark couldn't serialize strings > 64kB.

This fix was written by @tyro89 and @bouk in #512.
This commit squashes and rebases their pull request
in order to fix some merge conflicts.
parent 84670f27
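For context on the bug: java.io.DataOutputStream.writeUTF() frames each string with an unsigned 16-bit length prefix followed by modified UTF-8 (MUTF-8) bytes, so any string whose encoded form exceeds 65535 bytes cannot be written at all. The writeUTF helper introduced in the diff below replaces that with a 4-byte length written via writeInt() followed by standard UTF-8 bytes. The Python sketch that follows is not part of the commit; it only illustrates the framing difference, and the variable names are illustrative.

import struct

payload = ("a" * 100000).encode("utf-8")      # encodes to well over 64kB

# Old framing (what DataOutputStream.writeUTF requires): unsigned 16-bit length prefix.
try:
    old_frame = struct.pack(">H", len(payload)) + payload
except struct.error as err:
    print("2-byte length prefix cannot hold %d bytes: %s" % (len(payload), err))

# New framing (PythonRDD.writeUTF): signed 32-bit big-endian length, then UTF-8 bytes.
new_frame = struct.pack(">i", len(payload)) + payload
print("new frame is %d bytes (4-byte header + payload)" % len(new_frame))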
@@ -64,7 +64,7 @@ private[spark] class PythonRDD[T: ClassTag](
         // Partition index
         dataOut.writeInt(split.index)
         // sparkFilesDir
-        dataOut.writeUTF(SparkFiles.getRootDirectory)
+        PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
         // Broadcast variables
         dataOut.writeInt(broadcastVars.length)
         for (broadcast <- broadcastVars) {
@@ -74,7 +74,9 @@ private[spark] class PythonRDD[T: ClassTag](
         }
         // Python includes (*.zip and *.egg files)
         dataOut.writeInt(pythonIncludes.length)
-        pythonIncludes.foreach(dataOut.writeUTF)
+        for (include <- pythonIncludes) {
+          PythonRDD.writeUTF(include, dataOut)
+        }
         dataOut.flush()
         // Serialized command:
         dataOut.writeInt(command.length)
@@ -228,7 +230,7 @@ private[spark] object PythonRDD {
          }
        case string: String =>
          newIter.asInstanceOf[Iterator[String]].foreach { str =>
-           dataOut.writeUTF(str)
+           writeUTF(str, dataOut)
          }
        case pair: Tuple2[_, _] =>
          pair._1 match {
@@ -241,8 +243,8 @@
            }
          case stringPair: String =>
            newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
-             dataOut.writeUTF(pair._1)
-             dataOut.writeUTF(pair._2)
+             writeUTF(pair._1, dataOut)
+             writeUTF(pair._2, dataOut)
            }
          case other =>
            throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
@@ -253,6 +255,12 @@
     }
   }
 
+  def writeUTF(str: String, dataOut: DataOutputStream) {
+    val bytes = str.getBytes("UTF-8")
+    dataOut.writeInt(bytes.length)
+    dataOut.write(bytes)
+  }
+
   def writeToFile[T](items: java.util.Iterator[T], filename: String) {
     import scala.collection.JavaConverters._
     writeToFile(items.asScala, filename)

/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.python

import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers

import org.apache.spark.api.python.PythonRDD

import java.io.{ByteArrayOutputStream, DataOutputStream}

class PythonRDDSuite extends FunSuite {

  test("Writing large strings to the worker") {
    val input: List[String] = List("a"*100000)
    val buffer = new DataOutputStream(new ByteArrayOutputStream)
    PythonRDD.writeIteratorToStream(input.iterator, buffer)
  }

}

@@ -27,7 +27,7 @@
 from pyspark.broadcast import Broadcast
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
+from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
@@ -234,7 +234,7 @@ class SparkContext(object):
         """
         minSplits = minSplits or min(self.defaultParallelism, 2)
         return RDD(self._jsc.textFile(name, minSplits), self,
-                   MUTF8Deserializer())
+                   UTF8Deserializer())
 
     def _checkpointFile(self, name, input_deserializer):
         jrdd = self._jsc.checkpointFile(name)
@@ -261,13 +261,13 @@ class MarshalSerializer(FramedSerializer):
     loads = marshal.loads
 
 
-class MUTF8Deserializer(Serializer):
+class UTF8Deserializer(Serializer):
     """
-    Deserializes streams written by Java's DataOutputStream.writeUTF().
+    Deserializes streams written by getBytes.
     """
 
     def loads(self, stream):
-        length = struct.unpack('>H', stream.read(2))[0]
+        length = read_int(stream)
         return stream.read(length).decode('utf8')
 
     def load_stream(self, stream):
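Taken together, the Scala writeUTF() added above and the renamed UTF8Deserializer implement one simple length-prefixed protocol: a big-endian 4-byte byte count followed by the UTF-8 payload. The sketch below is not part of the commit; it emulates the Java writer in pure Python and reads the result back the way UTF8Deserializer.loads() does, with a local read_int standing in for the helper of the same name in pyspark.serializers, to show that a string far larger than 64kB now round-trips.

import io
import struct

def read_int(stream):
    # Stand-in for pyspark.serializers.read_int: big-endian signed 32-bit integer.
    return struct.unpack(">i", stream.read(4))[0]

def write_utf(s, out):
    # Emulates the new Scala PythonRDD.writeUTF: 4-byte length, then UTF-8 bytes.
    data = s.encode("utf-8")
    out.write(struct.pack(">i", len(data)))
    out.write(data)

def utf8_loads(stream):
    # Mirrors UTF8Deserializer.loads: read the length, then decode that many bytes.
    length = read_int(stream)
    return stream.read(length).decode("utf-8")

original = "a" * 100000                 # well past the old 64kB writeUTF limit
buf = io.BytesIO()
write_utf(original, buf)
buf.seek(0)
assert utf8_loads(buf) == original      # round-trips without truncation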
@@ -30,11 +30,11 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
-    write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
+    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
 
 
 pickleSer = PickleSerializer()
-mutf8_deserializer = MUTF8Deserializer()
+utf8_deserializer = UTF8Deserializer()
 
 
 def report_times(outfile, boot, init, finish):
@@ -51,7 +51,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = mutf8_deserializer.loads(infile)
+    spark_files_dir = utf8_deserializer.loads(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
 
@@ -66,7 +66,7 @@ def main(infile, outfile):
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes = read_int(infile)
     for _ in range(num_python_includes):
-        filename = mutf8_deserializer.loads(infile)
+        filename = utf8_deserializer.loads(infile)
         sys.path.append(os.path.join(spark_files_dir, filename))
 
     command = pickleSer._read_with_length(infile)