diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala index 127c9728da2d1a3d3b67af40c47c7b89da428d5f..676a3d3bca9f7b733babb195354007e47cd16368 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParameters.scala @@ -19,7 +19,10 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.Charset +import org.apache.hadoop.io.compress._ + import org.apache.spark.Logging +import org.apache.spark.util.Utils private[sql] case class CSVParameters(@transient parameters: Map[String, String]) extends Logging { @@ -35,7 +38,7 @@ private[sql] case class CSVParameters(@transient parameters: Map[String, String] private def getBool(paramName: String, default: Boolean = false): Boolean = { val param = parameters.getOrElse(paramName, default.toString) - if (param.toLowerCase() == "true") { + if (param.toLowerCase == "true") { true } else if (param.toLowerCase == "false") { false @@ -73,6 +76,11 @@ private[sql] case class CSVParameters(@transient parameters: Map[String, String] val nullValue = parameters.getOrElse("nullValue", "") + val compressionCodec: Option[String] = { + val name = parameters.get("compression").orElse(parameters.get("codec")) + name.map(CSVCompressionCodecs.getCodecClassName) + } + val maxColumns = 20480 val maxCharsPerColumn = 100000 @@ -85,7 +93,6 @@ private[sql] case class CSVParameters(@transient parameters: Map[String, String] } private[csv] object ParseModes { - val PERMISSIVE_MODE = "PERMISSIVE" val DROP_MALFORMED_MODE = "DROPMALFORMED" val FAIL_FAST_MODE = "FAILFAST" @@ -107,3 +114,28 @@ private[csv] object ParseModes { true // We default to permissive is the mode string is not valid } } + +private[csv] object CSVCompressionCodecs { + private val shortCompressionCodecNames = Map( + "bzip2" -> classOf[BZip2Codec].getName, + "gzip" -> classOf[GzipCodec].getName, + "lz4" -> classOf[Lz4Codec].getName, + "snappy" -> classOf[SnappyCodec].getName) + + /** + * Return the full version of the given codec class. + * If it is already a class name, just return it. + */ + def getCodecClassName(name: String): String = { + val codecName = shortCompressionCodecNames.getOrElse(name.toLowerCase, name) + try { + // Validate the codec name + Utils.classForName(codecName) + codecName + } catch { + case e: ClassNotFoundException => + throw new IllegalArgumentException(s"Codec [$codecName] " + + s"is not available. Known codecs are ${shortCompressionCodecNames.keys.mkString(", ")}.") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 53818853ffb3befaee405b05526850047fed400b..1502501c3b89e5cc03f1e18cb1982f464795b71b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -24,6 +24,7 @@ import scala.util.control.NonFatal import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, NullWritable, Text} +import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.hadoop.mapreduce.RecordWriter @@ -99,6 +100,15 @@ private[csv] class CSVRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = { + val conf = job.getConfiguration + params.compressionCodec.foreach { codec => + conf.set("mapreduce.output.fileoutputformat.compress", "true") + conf.set("mapreduce.output.fileoutputformat.compress.type", CompressionType.BLOCK.toString) + conf.set("mapreduce.output.fileoutputformat.compress.codec", codec) + conf.set("mapreduce.map.output.compress", "true") + conf.set("mapreduce.map.output.compress.codec", codec) + } + new CSVOutputWriterFactory(params) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 071b5ef56d58b29f7245844cc45a6a6857452e67..a79566b1f365813fa76d701b455fbdb8dffeb0d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -349,4 +349,30 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { assert(results(0).toSeq === Array(2012, "Tesla", "S", "null", "null")) assert(results(2).toSeq === Array(null, "Chevy", "Volt", null, null)) } + + test("save csv with compression codec option") { + withTempDir { dir => + val csvDir = new File(dir, "csv").getCanonicalPath + val cars = sqlContext.read + .format("csv") + .option("header", "true") + .load(testFile(carsFile)) + + cars.coalesce(1).write + .format("csv") + .option("header", "true") + .option("compression", "gZiP") + .save(csvDir) + + val compressedFiles = new File(csvDir).listFiles() + assert(compressedFiles.exists(_.getName.endsWith(".gz"))) + + val carsCopy = sqlContext.read + .format("csv") + .option("header", "true") + .load(csvDir) + + verifyCars(carsCopy, withHeader = true) + } + } }