From 60e18ce7dd1016647b63586520b713efc45494a8 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@databricks.com>
Date: Fri, 4 Apr 2014 17:29:29 -0700
Subject: [PATCH] SPARK-1414. Python API for SparkContext.wholeTextFiles

Also clarified the comment that each file has to fit in memory

Author: Matei Zaharia <matei@databricks.com>

Closes #327 from mateiz/py-whole-files and squashes the following commits:

9ad64a5 [Matei Zaharia] SPARK-1414. Python API for SparkContext.wholeTextFiles
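
For reference, a minimal usage sketch of the new PySpark API (illustrative
only; the master setting, app name, and hdfs:// path below are placeholders,
not part of this change):

    from pyspark import SparkContext

    # Placeholder master and app name; any existing SparkContext works the same way.
    sc = SparkContext("local[2]", "WholeTextFilesExample")

    # Each element is a (file path, file content) pair. Every file is read
    # fully into memory, so this is intended for directories of small files.
    pairs = sc.wholeTextFiles("hdfs://a-hdfs-path")
    paths = pairs.map(lambda kv: kv[0]).collect()

    sc.stop()
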
---
 .../scala/org/apache/spark/SparkContext.scala |  2 +-
 .../spark/api/java/JavaSparkContext.scala     |  2 +-
 .../apache/spark/api/python/PythonRDD.scala   |  6 ++-
 python/pyspark/context.py                     | 44 ++++++++++++++++++-
 python/pyspark/serializers.py                 |  2 +-
 5 files changed, 49 insertions(+), 7 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 28a865c0ad..835cffe37a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -395,7 +395,7 @@ class SparkContext(
    *   (a-hdfs-path/part-nnnnn, its content)
    * }}}
    *
-   * @note Small files are perferred, large file is also allowable, but may cause bad performance.
+   * @note Small files are preferred, as each file will be loaded fully in memory.
    */
   def wholeTextFiles(path: String): RDD[(String, String)] = {
     newAPIHadoopFile(
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 6cbdeac58d..a2855d4db1 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -177,7 +177,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    *   (a-hdfs-path/part-nnnnn, its content)
    * }}}
    *
-   * @note Small files are perferred, large file is also allowable, but may cause bad performance.
+   * @note Small files are preferred, as each file will be loaded fully in memory.
    */
   def wholeTextFiles(path: String): JavaPairRDD[String, String] =
     new JavaPairRDD(sc.wholeTextFiles(path))
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index b67286a4e3..32f1100406 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.api.python
 
 import java.io._
 import java.net._
+import java.nio.charset.Charset
 import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
 
 import scala.collection.JavaConversions._
@@ -206,6 +207,7 @@ private object SpecialLengths {
 }
 
 private[spark] object PythonRDD {
+  val UTF8 = Charset.forName("UTF-8")
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
@@ -266,7 +268,7 @@ private[spark] object PythonRDD {
   }
 
   def writeUTF(str: String, dataOut: DataOutputStream) {
-    val bytes = str.getBytes("UTF-8")
+    val bytes = str.getBytes(UTF8)
     dataOut.writeInt(bytes.length)
     dataOut.write(bytes)
   }
@@ -286,7 +288,7 @@ private[spark] object PythonRDD {
 
 private
 class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
-  override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
+  override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8)
 }
 
 /**
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index bf2454fd7e..ff1023bbfa 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -28,7 +28,8 @@ from pyspark.broadcast import Broadcast
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
+from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
+        PairDeserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark import rdd
 from pyspark.rdd import RDD
@@ -257,6 +258,45 @@ class SparkContext(object):
         return RDD(self._jsc.textFile(name, minSplits), self,
                    UTF8Deserializer())
 
+    def wholeTextFiles(self, path):
+        """
+        Read a directory of text files from HDFS, a local file system
+        (available on all nodes), or any Hadoop-supported file system
+        URI. Each file is read as a single record and returned in a
+        key-value pair, where the key is the path of each file and the
+        value is the content of each file.
+
+        For example, if you have the following files::
+
+          hdfs://a-hdfs-path/part-00000
+          hdfs://a-hdfs-path/part-00001
+          ...
+          hdfs://a-hdfs-path/part-nnnnn
+
+        Do C{rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")},
+        then C{rdd} contains::
+
+          (a-hdfs-path/part-00000, its content)
+          (a-hdfs-path/part-00001, its content)
+          ...
+          (a-hdfs-path/part-nnnnn, its content)
+
+        NOTE: Small files are preferred, as each file will be loaded
+        fully in memory.
+
+        >>> dirPath = os.path.join(tempdir, "files")
+        >>> os.mkdir(dirPath)
+        >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1:
+        ...    file1.write("1")
+        >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2:
+        ...    file2.write("2")
+        >>> textFiles = sc.wholeTextFiles(dirPath)
+        >>> sorted(textFiles.collect())
+        [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')]
+        """
+        return RDD(self._jsc.wholeTextFiles(path), self,
+                   PairDeserializer(UTF8Deserializer(), UTF8Deserializer()))
+
     def _checkpointFile(self, name, input_deserializer):
         jrdd = self._jsc.checkpointFile(name)
         return RDD(jrdd, self, input_deserializer)
@@ -425,7 +465,7 @@ def _test():
     globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
     globs['tempdir'] = tempfile.mkdtemp()
     atexit.register(lambda: shutil.rmtree(globs['tempdir']))
-    (failure_count, test_count) = doctest.testmod(globs=globs)
+    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
         exit(-1)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 12c63f186a..4d802924df 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -290,7 +290,7 @@ class MarshalSerializer(FramedSerializer):
 
 class UTF8Deserializer(Serializer):
     """
-    Deserializes streams written by getBytes.
+    Deserializes streams written by String.getBytes.
     """
 
     def loads(self, stream):
-- 
GitLab