diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index a3691158ee7587e861ca1ff7c51fc6174bcbef29..e627f040d3cc87cf749d38bc5f472aebb8e1b62c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -27,10 +27,12 @@ import org.apache.hadoop.mapreduce._
 
 import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
+import org.apache.spark.sql.functions.{length, trim}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
@@ -52,17 +54,22 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType] = {
+    require(files.nonEmpty, "Cannot infer schema from an empty set of files")
     val csvOptions = new CSVOptions(options)
 
     // TODO: Move filtering.
     val paths = files.filterNot(_.getPath.getName startsWith "_").map(_.getPath.toString)
-    val rdd = baseRdd(sparkSession, csvOptions, paths)
-    val firstLine = findFirstLine(csvOptions, rdd)
+    val lines: Dataset[String] = readText(sparkSession, csvOptions, paths)
+    val firstLine: String = findFirstLine(csvOptions, lines)
     val firstRow = new CsvReader(csvOptions).parseLine(firstLine)
     val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
     val header = makeSafeHeader(firstRow, csvOptions, caseSensitive)
 
-    val parsedRdd = tokenRdd(sparkSession, csvOptions, header, paths)
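+    // When the header option is set, pass the materialized first line so the tokenizer can filter it out.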
+    val parsedRdd: RDD[Array[String]] = CSVRelation.univocityTokenizer(
+      lines,
+      firstLine = if (csvOptions.headerFlag) firstLine else null,
+      params = csvOptions)
     val schema = if (csvOptions.inferSchemaFlag) {
       CSVInferSchema.infer(parsedRdd, header, csvOptions)
     } else {
@@ -173,51 +180,37 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
     }
   }
 
-  private def baseRdd(
-      sparkSession: SparkSession,
-      options: CSVOptions,
-      inputPaths: Seq[String]): RDD[String] = {
-    readText(sparkSession, options, inputPaths.mkString(","))
-  }
-
-  private def tokenRdd(
-      sparkSession: SparkSession,
-      options: CSVOptions,
-      header: Array[String],
-      inputPaths: Seq[String]): RDD[Array[String]] = {
-    val rdd = baseRdd(sparkSession, options, inputPaths)
-    // Make sure firstLine is materialized before sending to executors
-    val firstLine = if (options.headerFlag) findFirstLine(options, rdd) else null
-    CSVRelation.univocityTokenizer(rdd, firstLine, options)
-  }
-
   /**
-   * Returns the first line of the first non-empty file in path
+   * Returns the first non-blank, non-comment line from the given dataset of lines
    */
-  private def findFirstLine(options: CSVOptions, rdd: RDD[String]): String = {
+  private def findFirstLine(options: CSVOptions, lines: Dataset[String]): String = {
+    import lines.sqlContext.implicits._
+    val nonEmptyLines = lines.filter(length(trim($"value")) > 0)
     if (options.isCommentSet) {
-      val comment = options.comment.toString
-      rdd.filter { line =>
-        line.trim.nonEmpty && !line.startsWith(comment)
-      }.first()
+      nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)).first()
     } else {
-      rdd.filter { line =>
-        line.trim.nonEmpty
-      }.first()
+      nonEmptyLines.first()
     }
   }
 
   private def readText(
       sparkSession: SparkSession,
       options: CSVOptions,
-      location: String): RDD[String] = {
+      inputPaths: Seq[String]): Dataset[String] = {
     if (Charset.forName(options.charset) == StandardCharsets.UTF_8) {
-      sparkSession.sparkContext.textFile(location)
+      sparkSession.baseRelationToDataFrame(
+        DataSource.apply(
+          sparkSession,
+          paths = inputPaths,
+          className = classOf[TextFileFormat].getName
+        ).resolveRelation(checkFilesExist = false))
+        .select("value").as[String](Encoders.STRING)
     } else {
       val charset = options.charset
-      sparkSession.sparkContext
-        .hadoopFile[LongWritable, Text, TextInputFormat](location)
+      val rdd = sparkSession.sparkContext
+        .hadoopFile[LongWritable, Text, TextInputFormat](inputPaths.mkString(","))
         .mapPartitions(_.map(pair => new String(pair._2.getBytes, 0, pair._2.getLength, charset)))
+      sparkSession.createDataset(rdd)(Encoders.STRING)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 52de11d403446efeb4f3b9607ce118c74a02b062..e4ce7a94be7df1874534fc6e02a3612314bdbe37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -34,12 +34,13 @@ import org.apache.spark.sql.types._
 object CSVRelation extends Logging {
 
   def univocityTokenizer(
-      file: RDD[String],
+      file: Dataset[String],
       firstLine: String,
       params: CSVOptions): RDD[Array[String]] = {
     // If header is set, make sure firstLine is materialized before sending to executors.
     val commentPrefix = params.comment.toString
-    file.mapPartitions { iter =>
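+    // Drop to the underlying RDD so the per-partition tokenization below is unchanged.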
+    file.rdd.mapPartitions { iter =>
       val parser = new CsvReader(params)
       val filteredIter = iter.filter { line =>
         line.trim.nonEmpty && !line.startsWith(commentPrefix)
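
Note (not part of the patch): the snippet below is a minimal, self-contained sketch of the Dataset-based flow this change moves to. It assumes a local SparkSession and a placeholder input path; spark.read.textFile is used here as the public counterpart of resolving TextFileFormat through DataSource, and the first-line lookup mirrors the refactored findFirstLine.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{length, trim}

object CsvFirstLineSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder path and comment character -- adjust for a real dataset.
    val path = "/tmp/people.csv"
    val comment = "#"

    val spark = SparkSession.builder()
      .appName("csv-first-line-sketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Read the raw lines as a Dataset[String]; the patched readText resolves
    // TextFileFormat through DataSource to get the same thing internally.
    val lines = spark.read.textFile(path)

    // Mirror findFirstLine: drop blank lines, then comment lines, take the first.
    val firstLine = lines
      .filter(length(trim($"value")) > 0)
      .filter(!$"value".startsWith(comment))
      .first()

    println(s"Header candidate: $firstLine")
    spark.stop()
  }
}

When the header option is set, this is the line inferSchema now hands to CSVRelation.univocityTokenizer so the tokenizer can filter it out of the parsed rows.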