Commit 60e18ce7 authored by Matei Zaharia

SPARK-1414. Python API for SparkContext.wholeTextFiles

Also clarified comment on each file having to fit in memory

Author: Matei Zaharia <matei@databricks.com>

Closes #327 from mateiz/py-whole-files and squashes the following commits:

9ad64a5 [Matei Zaharia] SPARK-1414. Python API for SparkContext.wholeTextFiles
parent d956cc25
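For orientation, here is a minimal usage sketch of the new Python API. The master string, app name, and input path are illustrative only, not taken from the patch:

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "WholeTextFilesExample")  # illustrative master/app name

    # Each element is a (file path, file content) pair. Every file is read
    # fully into memory, hence the preference for many small files.
    pairs = sc.wholeTextFiles("hdfs://a-hdfs-path")  # illustrative path
    for path, content in pairs.collect():
        print("%s: %d bytes" % (path, len(content)))

    sc.stop()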
@@ -395,7 +395,7 @@ class SparkContext(
    *   (a-hdfs-path/part-nnnnn, its content)
    * }}}
    *
-   * @note Small files are perferred, large file is also allowable, but may cause bad performance.
+   * @note Small files are preferred, as each file will be loaded fully in memory.
    */
   def wholeTextFiles(path: String): RDD[(String, String)] = {
     newAPIHadoopFile(
...
@@ -177,7 +177,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    *   (a-hdfs-path/part-nnnnn, its content)
    * }}}
    *
-   * @note Small files are perferred, large file is also allowable, but may cause bad performance.
+   * @note Small files are preferred, as each file will be loaded fully in memory.
    */
   def wholeTextFiles(path: String): JavaPairRDD[String, String] =
     new JavaPairRDD(sc.wholeTextFiles(path))
...
@@ -19,6 +19,7 @@ package org.apache.spark.api.python
 import java.io._
 import java.net._
+import java.nio.charset.Charset
 import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

 import scala.collection.JavaConversions._
@@ -206,6 +207,7 @@ private object SpecialLengths {
 }

 private[spark] object PythonRDD {
+  val UTF8 = Charset.forName("UTF-8")

   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
@@ -266,7 +268,7 @@ private[spark] object PythonRDD {
   }

   def writeUTF(str: String, dataOut: DataOutputStream) {
-    val bytes = str.getBytes("UTF-8")
+    val bytes = str.getBytes(UTF8)
     dataOut.writeInt(bytes.length)
     dataOut.write(bytes)
   }
@@ -286,7 +288,7 @@ private[spark] object PythonRDD {
 private
 class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
-  override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
+  override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8)
 }

 /**
...
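For reference, writeUTF frames each string as a 4-byte big-endian length (DataOutputStream.writeInt) followed by the raw UTF-8 bytes. A minimal sketch of the matching read on the Python side, over any binary file-like stream; this mirrors what UTF8Deserializer consumes but is not the exact PySpark implementation:

    import struct

    def read_utf8_record(stream):
        # Read the 4-byte big-endian length written by DataOutputStream.writeInt,
        # then that many bytes, and decode them as UTF-8.
        length = struct.unpack("!i", stream.read(4))[0]
        return stream.read(length).decode("utf-8")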
@@ -28,7 +28,8 @@ from pyspark.broadcast import Broadcast
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
+from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
+    PairDeserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark import rdd
 from pyspark.rdd import RDD
@@ -257,6 +258,45 @@ class SparkContext(object):
         return RDD(self._jsc.textFile(name, minSplits), self,
                    UTF8Deserializer())

+    def wholeTextFiles(self, path):
+        """
+        Read a directory of text files from HDFS, a local file system
+        (available on all nodes), or any Hadoop-supported file system
+        URI. Each file is read as a single record and returned in a
+        key-value pair, where the key is the path of each file, the
+        value is the content of each file.
+
+        For example, if you have the following files::
+
+          hdfs://a-hdfs-path/part-00000
+          hdfs://a-hdfs-path/part-00001
+          ...
+          hdfs://a-hdfs-path/part-nnnnn
+
+        Do C{rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")},
+        then C{rdd} contains::
+
+          (a-hdfs-path/part-00000, its content)
+          (a-hdfs-path/part-00001, its content)
+          ...
+          (a-hdfs-path/part-nnnnn, its content)
+
+        NOTE: Small files are preferred, as each file will be loaded
+        fully in memory.
+
+        >>> dirPath = os.path.join(tempdir, "files")
+        >>> os.mkdir(dirPath)
+        >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1:
+        ...    file1.write("1")
+        >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2:
+        ...    file2.write("2")
+        >>> textFiles = sc.wholeTextFiles(dirPath)
+        >>> sorted(textFiles.collect())
+        [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')]
+        """
+        return RDD(self._jsc.wholeTextFiles(path), self,
+                   PairDeserializer(UTF8Deserializer(), UTF8Deserializer()))
+
     def _checkpointFile(self, name, input_deserializer):
         jrdd = self._jsc.checkpointFile(name)
         return RDD(jrdd, self, input_deserializer)
@@ -425,7 +465,7 @@ def _test():
     globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
     globs['tempdir'] = tempfile.mkdtemp()
     atexit.register(lambda: shutil.rmtree(globs['tempdir']))
-    (failure_count, test_count) = doctest.testmod(globs=globs)
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
         exit(-1)
...
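The new optionflags=doctest.ELLIPSIS is what allows the '...' placeholders in the wholeTextFiles doctest output (e.g. u'.../1.txt') to match the unpredictable temporary-directory prefix. A standalone illustration of the flag, using a made-up function purely for demonstration:

    import doctest

    def temp_file_path():
        """
        >>> temp_file_path()  # '...' matches the random directory under ELLIPSIS
        '/tmp/.../data.txt'
        """
        return "/tmp/spark-1234abcd/data.txt"  # hypothetical path for the demo

    if __name__ == "__main__":
        doctest.testmod(optionflags=doctest.ELLIPSIS)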
@@ -290,7 +290,7 @@ class MarshalSerializer(FramedSerializer):

 class UTF8Deserializer(Serializer):
     """
-    Deserializes streams written by getBytes.
+    Deserializes streams written by String.getBytes.
     """

     def loads(self, stream):
...