diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index cb847a042031127f5b0974446330bc1f4b6c6ca3..f3092918abb544a09440d32c7a35f0a3fcdf53d9 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -335,7 +335,8 @@ class DataFrameReader(OptionUtils):
         ``inferSchema`` is enabled. To avoid going through the entire data once, disable
         ``inferSchema`` option or specify the schema explicitly using ``schema``.
 
-        :param path: string, or list of strings, for input path(s).
+        :param path: string, or list of strings, for input path(s),
+                     or RDD of Strings storing CSV rows.
         :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema
                        or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
         :param sep: sets the single character as a separator for each field and value.
@@ -408,6 +409,10 @@ class DataFrameReader(OptionUtils):
         >>> df = spark.read.csv('python/test_support/sql/ages.csv')
         >>> df.dtypes
         [('_c0', 'string'), ('_c1', 'string')]
+        >>> rdd = sc.textFile('python/test_support/sql/ages.csv')
+        >>> df2 = spark.read.csv(rdd)
+        >>> df2.dtypes
+        [('_c0', 'string'), ('_c1', 'string')]
         """
         self._set_opts(
             schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
@@ -420,7 +425,29 @@ class DataFrameReader(OptionUtils):
             columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine)
         if isinstance(path, basestring):
             path = [path]
-        return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+        if type(path) == list:
+            return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
+        elif isinstance(path, RDD):
+            def func(iterator):
+                for x in iterator:
+                    if not isinstance(x, basestring):
+                        x = unicode(x)
+                    if isinstance(x, unicode):
+                        x = x.encode("utf-8")
+                    yield x
+            keyed = path.mapPartitions(func)
+            keyed._bypass_serializer = True
+            jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
+            # see SPARK-22112
+            # There aren't any jvm api for creating a dataframe from rdd storing csv.
+            # We can do it through creating a jvm dataset firstly and using the jvm api
+            # for creating a dataframe from dataset storing csv.
+            jdataset = self._spark._ssql_ctx.createDataset(
+                jrdd.rdd(),
+                self._spark._jvm.Encoders.STRING())
+            return self._df(self._jreader.csv(jdataset))
+        else:
+            raise TypeError("path can be only string, list or RDD")
 
     @since(1.5)
     def orc(self, path):
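
For context, a minimal usage sketch of the code path this diff adds: reading CSV rows from an in-memory RDD of strings rather than from a file path. It assumes a running SparkSession `spark` and its SparkContext `sc` (e.g. from a pyspark shell); the sample rows and column names are illustrative only, not part of the patch.

    # Sketch only: exercises the new RDD branch of DataFrameReader.csv (SPARK-22112).
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType

    # An RDD of strings, each string being one CSV row (made-up data).
    rows = sc.parallelize(["Alice,2", "Bob,5"])

    # Without a schema, columns come back as strings named _c0, _c1, as in the doctest above.
    df = spark.read.csv(rows)

    # Reader options and an explicit schema are passed exactly as for file-path input,
    # since they are applied via _set_opts before the path/RDD dispatch.
    schema = StructType([
        StructField("name", StringType(), True),
        StructField("age", IntegerType(), True),
    ])
    df_typed = spark.read.csv(rows, schema=schema)
    df_typed.show()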