diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f763106da4e0e84c2a722ad841509bec7a558dba..394a59700dbafe109bf40741d409bb777456db47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -140,12 +140,35 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil - case PhysicalOperation(projectList, filters, relation: ParquetRelation) => - // TODO: Should be pushing down filters as well. + case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => { + val remainingFilters = + if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { + filters.filter { + // Note: filters cannot be pushed down to Parquet if they contain more complex + // expressions than simple "Attribute cmp Literal" comparisons. Here we remove + // all filters that have been pushed down. Note that a predicate such as + // "(A AND B) OR C" can result in "A OR C" being pushed down. + filter => + val recordFilter = ParquetFilters.createFilter(filter) + if (!recordFilter.isDefined) { + // First case: the pushdown did not result in any record filter. + true + } else { + // Second case: a record filter was created; here we are conservative in + // the sense that even if "A" was pushed and we check for "A AND B" we + // still want to keep "A AND B" in the higher-level filter, not just "B". + !ParquetFilters.findExpression(recordFilter.get, filter).isDefined + } + } + } else { + filters + } pruneFilterProject( projectList, - filters, - ParquetTableScan(_, relation, None)(sparkContext)) :: Nil + remainingFilters, + ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil + } + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala new file mode 100644 index 0000000000000000000000000000000000000000..052b0a9196717489335d74d9155b88bdf5556da7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import org.apache.hadoop.conf.Configuration + +import parquet.filter._ +import parquet.filter.ColumnPredicates._ +import parquet.column.ColumnReader + +import com.google.common.io.BaseEncoding + +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkSqlSerializer + +object ParquetFilters { + val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" + // set this to false if pushdown should be disabled + val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.hints.parquetFilterPushdown" + + def createRecordFilter(filterExpressions: Seq[Expression]): UnboundRecordFilter = { + val filters: Seq[CatalystFilter] = filterExpressions.collect { + case (expression: Expression) if createFilter(expression).isDefined => + createFilter(expression).get + } + if (filters.length > 0) filters.reduce(AndRecordFilter.and) else null + } + + def createFilter(expression: Expression): Option[CatalystFilter] = { + def createEqualityFilter( + name: String, + literal: Literal, + predicate: CatalystPredicate) = literal.dataType match { + case BooleanType => + ComparisonFilter.createBooleanFilter(name, literal.value.asInstanceOf[Boolean], predicate) + case IntegerType => + ComparisonFilter.createIntFilter( + name, + (x: Int) => x == literal.value.asInstanceOf[Int], + predicate) + case LongType => + ComparisonFilter.createLongFilter( + name, + (x: Long) => x == literal.value.asInstanceOf[Long], + predicate) + case DoubleType => + ComparisonFilter.createDoubleFilter( + name, + (x: Double) => x == literal.value.asInstanceOf[Double], + predicate) + case FloatType => + ComparisonFilter.createFloatFilter( + name, + (x: Float) => x == literal.value.asInstanceOf[Float], + predicate) + case StringType => + ComparisonFilter.createStringFilter(name, literal.value.asInstanceOf[String], predicate) + } + def createLessThanFilter( + name: String, + literal: Literal, + predicate: CatalystPredicate) = literal.dataType match { + case IntegerType => + ComparisonFilter.createIntFilter( + name, + (x: Int) => x < literal.value.asInstanceOf[Int], + predicate) + case LongType => + ComparisonFilter.createLongFilter( + name, + (x: Long) => x < literal.value.asInstanceOf[Long], + predicate) + case DoubleType => + ComparisonFilter.createDoubleFilter( + name, + (x: Double) => x < literal.value.asInstanceOf[Double], + predicate) + case FloatType => + ComparisonFilter.createFloatFilter( + name, + (x: Float) => x < literal.value.asInstanceOf[Float], + predicate) + } + def createLessThanOrEqualFilter( + name: String, + literal: Literal, + predicate: CatalystPredicate) = literal.dataType match { + case IntegerType => + ComparisonFilter.createIntFilter( + name, + (x: Int) => x <= literal.value.asInstanceOf[Int], + predicate) + case LongType => + ComparisonFilter.createLongFilter( + name, + (x: Long) => x <= literal.value.asInstanceOf[Long], + predicate) + case DoubleType => + ComparisonFilter.createDoubleFilter( + name, + (x: Double) => x <= literal.value.asInstanceOf[Double], + predicate) + case FloatType => + ComparisonFilter.createFloatFilter( + name, + (x: Float) => x <= literal.value.asInstanceOf[Float], + predicate) + } + // TODO: combine these two types somehow? + def createGreaterThanFilter( + name: String, + literal: Literal, + predicate: CatalystPredicate) = literal.dataType match { + case IntegerType => + ComparisonFilter.createIntFilter( + name, + (x: Int) => x > literal.value.asInstanceOf[Int], + predicate) + case LongType => + ComparisonFilter.createLongFilter( + name, + (x: Long) => x > literal.value.asInstanceOf[Long], + predicate) + case DoubleType => + ComparisonFilter.createDoubleFilter( + name, + (x: Double) => x > literal.value.asInstanceOf[Double], + predicate) + case FloatType => + ComparisonFilter.createFloatFilter( + name, + (x: Float) => x > literal.value.asInstanceOf[Float], + predicate) + } + def createGreaterThanOrEqualFilter( + name: String, + literal: Literal, + predicate: CatalystPredicate) = literal.dataType match { + case IntegerType => + ComparisonFilter.createIntFilter( + name, (x: Int) => x >= literal.value.asInstanceOf[Int], + predicate) + case LongType => + ComparisonFilter.createLongFilter( + name, + (x: Long) => x >= literal.value.asInstanceOf[Long], + predicate) + case DoubleType => + ComparisonFilter.createDoubleFilter( + name, + (x: Double) => x >= literal.value.asInstanceOf[Double], + predicate) + case FloatType => + ComparisonFilter.createFloatFilter( + name, + (x: Float) => x >= literal.value.asInstanceOf[Float], + predicate) + } + + /** + * TODO: we currently only filter on non-nullable (Parquet REQUIRED) attributes until + * https://github.com/Parquet/parquet-mr/issues/371 + * has been resolved. + */ + expression match { + case p @ Or(left: Expression, right: Expression) + if createFilter(left).isDefined && createFilter(right).isDefined => { + // If either side of this Or-predicate is empty then this means + // it contains a more complex comparison than between attribute and literal + // (e.g., it contained a CAST). The only safe thing to do is then to disregard + // this disjunction, which could be contained in a conjunction. If it stands + // alone then it is also safe to drop it, since a Null return value of this + // function is interpreted as having no filters at all. + val leftFilter = createFilter(left).get + val rightFilter = createFilter(right).get + Some(new OrFilter(leftFilter, rightFilter)) + } + case p @ And(left: Expression, right: Expression) => { + // This treats nested conjunctions; since either side of the conjunction + // may contain more complex filter expressions we may actually generate + // strictly weaker filter predicates in the process. + val leftFilter = createFilter(left) + val rightFilter = createFilter(right) + (leftFilter, rightFilter) match { + case (None, Some(filter)) => Some(filter) + case (Some(filter), None) => Some(filter) + case (_, _) => + Some(new AndFilter(leftFilter.get, rightFilter.get)) + } + } + case p @ Equals(left: Literal, right: NamedExpression) if !right.nullable => + Some(createEqualityFilter(right.name, left, p)) + case p @ Equals(left: NamedExpression, right: Literal) if !left.nullable => + Some(createEqualityFilter(left.name, right, p)) + case p @ LessThan(left: Literal, right: NamedExpression) if !right.nullable => + Some(createLessThanFilter(right.name, left, p)) + case p @ LessThan(left: NamedExpression, right: Literal) if !left.nullable => + Some(createLessThanFilter(left.name, right, p)) + case p @ LessThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable => + Some(createLessThanOrEqualFilter(right.name, left, p)) + case p @ LessThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable => + Some(createLessThanOrEqualFilter(left.name, right, p)) + case p @ GreaterThan(left: Literal, right: NamedExpression) if !right.nullable => + Some(createGreaterThanFilter(right.name, left, p)) + case p @ GreaterThan(left: NamedExpression, right: Literal) if !left.nullable => + Some(createGreaterThanFilter(left.name, right, p)) + case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable => + Some(createGreaterThanOrEqualFilter(right.name, left, p)) + case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable => + Some(createGreaterThanOrEqualFilter(left.name, right, p)) + case _ => None + } + } + + /** + * Note: Inside the Hadoop API we only have access to `Configuration`, not to + * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey + * the actual filter predicate. + */ + def serializeFilterExpressions(filters: Seq[Expression], conf: Configuration): Unit = { + if (filters.length > 0) { + val serialized: Array[Byte] = SparkSqlSerializer.serialize(filters) + val encoded: String = BaseEncoding.base64().encode(serialized) + conf.set(PARQUET_FILTER_DATA, encoded) + } + } + + /** + * Note: Inside the Hadoop API we only have access to `Configuration`, not to + * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey + * the actual filter predicate. + */ + def deserializeFilterExpressions(conf: Configuration): Seq[Expression] = { + val data = conf.get(PARQUET_FILTER_DATA) + if (data != null) { + val decoded: Array[Byte] = BaseEncoding.base64().decode(data) + SparkSqlSerializer.deserialize(decoded) + } else { + Seq() + } + } + + /** + * Try to find the given expression in the tree of filters in order to + * determine whether it is safe to remove it from the higher level filters. Note + * that strictly speaking we could stop the search whenever an expression is found + * that contains this expression as subexpression (e.g., when searching for "a" + * and "(a or c)" is found) but we don't care about optimizations here since the + * filter tree is assumed to be small. + * + * @param filter The [[org.apache.spark.sql.parquet.CatalystFilter]] to expand + * and search + * @param expression The expression to look for + * @return An optional [[org.apache.spark.sql.parquet.CatalystFilter]] that + * contains the expression. + */ + def findExpression( + filter: CatalystFilter, + expression: Expression): Option[CatalystFilter] = filter match { + case f @ OrFilter(_, leftFilter, rightFilter, _) => + if (f.predicate == expression) { + Some(f) + } else { + val left = findExpression(leftFilter, expression) + if (left.isDefined) left else findExpression(rightFilter, expression) + } + case f @ AndFilter(_, leftFilter, rightFilter, _) => + if (f.predicate == expression) { + Some(f) + } else { + val left = findExpression(leftFilter, expression) + if (left.isDefined) left else findExpression(rightFilter, expression) + } + case f @ ComparisonFilter(_, _, predicate) => + if (predicate == expression) Some(f) else None + case _ => None + } +} + +abstract private[parquet] class CatalystFilter( + @transient val predicate: CatalystPredicate) extends UnboundRecordFilter + +private[parquet] case class ComparisonFilter( + val columnName: String, + private var filter: UnboundRecordFilter, + @transient override val predicate: CatalystPredicate) + extends CatalystFilter(predicate) { + override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { + filter.bind(readers) + } +} + +private[parquet] case class OrFilter( + private var filter: UnboundRecordFilter, + @transient val left: CatalystFilter, + @transient val right: CatalystFilter, + @transient override val predicate: Or) + extends CatalystFilter(predicate) { + def this(l: CatalystFilter, r: CatalystFilter) = + this( + OrRecordFilter.or(l, r), + l, + r, + Or(l.predicate, r.predicate)) + + override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { + filter.bind(readers) + } +} + +private[parquet] case class AndFilter( + private var filter: UnboundRecordFilter, + @transient val left: CatalystFilter, + @transient val right: CatalystFilter, + @transient override val predicate: And) + extends CatalystFilter(predicate) { + def this(l: CatalystFilter, r: CatalystFilter) = + this( + AndRecordFilter.and(l, r), + l, + r, + And(l.predicate, r.predicate)) + + override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { + filter.bind(readers) + } +} + +private[parquet] object ComparisonFilter { + def createBooleanFilter( + columnName: String, + value: Boolean, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + ColumnRecordFilter.column( + columnName, + ColumnPredicates.applyFunctionToBoolean( + new BooleanPredicateFunction { + def functionToApply(input: Boolean): Boolean = input == value + } + )), + predicate) + + def createStringFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + ColumnRecordFilter.column( + columnName, + ColumnPredicates.applyFunctionToString ( + new ColumnPredicates.PredicateFunction[String] { + def functionToApply(input: String): Boolean = input == value + } + )), + predicate) + + def createIntFilter( + columnName: String, + func: Int => Boolean, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + ColumnRecordFilter.column( + columnName, + ColumnPredicates.applyFunctionToInteger( + new IntegerPredicateFunction { + def functionToApply(input: Int) = func(input) + } + )), + predicate) + + def createLongFilter( + columnName: String, + func: Long => Boolean, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + ColumnRecordFilter.column( + columnName, + ColumnPredicates.applyFunctionToLong( + new LongPredicateFunction { + def functionToApply(input: Long) = func(input) + } + )), + predicate) + + def createDoubleFilter( + columnName: String, + func: Double => Boolean, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + ColumnRecordFilter.column( + columnName, + ColumnPredicates.applyFunctionToDouble( + new DoublePredicateFunction { + def functionToApply(input: Double) = func(input) + } + )), + predicate) + + def createFloatFilter( + columnName: String, + func: Float => Boolean, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + ColumnRecordFilter.column( + columnName, + ColumnPredicates.applyFunctionToFloat( + new FloatPredicateFunction { + def functionToApply(input: Float) = func(input) + } + )), + predicate) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index f825ca3c028efa3503e67686388de0b2dea82b73..65ba1246fbf9aa042c2c6f007709571204d46c93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -27,26 +27,27 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat, FileOutputCommitter} -import parquet.hadoop.{ParquetInputFormat, ParquetOutputFormat} +import parquet.hadoop.{ParquetRecordReader, ParquetInputFormat, ParquetOutputFormat} +import parquet.hadoop.api.ReadSupport import parquet.hadoop.util.ContextUtil import parquet.io.InvalidRecordException import parquet.schema.MessageType -import org.apache.spark.{SerializableWritable, SparkContext, TaskContext} +import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} /** * Parquet table scan operator. Imports the file that backs the given - * [[ParquetRelation]] as a RDD[Row]. + * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``. */ case class ParquetTableScan( // note: output cannot be transient, see // https://issues.apache.org/jira/browse/SPARK-1367 output: Seq[Attribute], relation: ParquetRelation, - columnPruningPred: Option[Expression])( + columnPruningPred: Seq[Expression])( @transient val sc: SparkContext) extends LeafNode { @@ -62,18 +63,30 @@ case class ParquetTableScan( for (path <- fileList if !path.getName.startsWith("_")) { NewFileInputFormat.addInputPath(job, path) } + + // Store Parquet schema in `Configuration` conf.set( RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, ParquetTypesConverter.convertFromAttributes(output).toString) - // TODO: think about adding record filters - /* Comments regarding record filters: it would be nice to push down as much filtering - to Parquet as possible. However, currently it seems we cannot pass enough information - to materialize an (arbitrary) Catalyst [[Predicate]] inside Parquet's - ``FilteredRecordReader`` (via Configuration, for example). Simple - filter-rows-by-column-values however should be supported. - */ - sc.newAPIHadoopRDD(conf, classOf[ParquetInputFormat[Row]], classOf[Void], classOf[Row]) - .map(_._2) + + // Store record filtering predicate in `Configuration` + // Note 1: the input format ignores all predicates that cannot be expressed + // as simple column predicate filters in Parquet. Here we just record + // the whole pruning predicate. + // Note 2: you can disable filter predicate pushdown by setting + // "spark.sql.hints.parquetFilterPushdown" to false inside SparkConf. + if (columnPruningPred.length > 0 && + sc.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { + ParquetFilters.serializeFilterExpressions(columnPruningPred, conf) + } + + sc.newAPIHadoopRDD( + conf, + classOf[org.apache.spark.sql.parquet.FilteringParquetRowInputFormat], + classOf[Void], + classOf[Row]) + .map(_._2) + .filter(_ != null) // Parquet's record filters may produce null values } override def otherCopyArgs = sc :: Nil @@ -184,10 +197,19 @@ case class InsertIntoParquetTable( override def otherCopyArgs = sc :: Nil - // based on ``saveAsNewAPIHadoopFile`` in [[PairRDDFunctions]] - // TODO: Maybe PairRDDFunctions should use Product2 instead of Tuple2? - // .. then we could use the default one and could use [[MutablePair]] - // instead of ``Tuple2`` + /** + * Stores the given Row RDD as a Hadoop file. + * + * Note: We cannot use ``saveAsNewAPIHadoopFile`` from [[org.apache.spark.rdd.PairRDDFunctions]] + * together with [[org.apache.spark.util.MutablePair]] because ``PairRDDFunctions`` uses + * ``Tuple2`` and not ``Product2``. Also, we want to allow appending files to an existing + * directory and need to determine which was the largest written file index before starting to + * write. + * + * @param rdd The [[org.apache.spark.rdd.RDD]] to writer + * @param path The directory to write to. + * @param conf A [[org.apache.hadoop.conf.Configuration]]. + */ private def saveAsHadoopFile( rdd: RDD[Row], path: String, @@ -244,8 +266,10 @@ case class InsertIntoParquetTable( } } -// TODO: this will be able to append to directories it created itself, not necessarily -// to imported ones +/** + * TODO: this will be able to append to directories it created itself, not necessarily + * to imported ones. + */ private[parquet] class AppendingParquetOutputFormat(offset: Int) extends parquet.hadoop.ParquetOutputFormat[Row] { // override to accept existing directories as valid output directory @@ -262,6 +286,30 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int) } } +/** + * We extend ParquetInputFormat in order to have more control over which + * RecordFilter we want to use. + */ +private[parquet] class FilteringParquetRowInputFormat + extends parquet.hadoop.ParquetInputFormat[Row] with Logging { + override def createRecordReader( + inputSplit: InputSplit, + taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { + val readSupport: ReadSupport[Row] = new RowReadSupport() + + val filterExpressions = + ParquetFilters.deserializeFilterExpressions(ContextUtil.getConfiguration(taskAttemptContext)) + if (filterExpressions.length > 0) { + logInfo(s"Pushing down predicates for RecordFilter: ${filterExpressions.mkString(", ")}") + new ParquetRecordReader[Row]( + readSupport, + ParquetFilters.createRecordFilter(filterExpressions)) + } else { + new ParquetRecordReader[Row](readSupport) + } + } +} + private[parquet] object FileSystemHelper { def listFiles(pathStr: String, conf: Configuration): Seq[Path] = { val origPath = new Path(pathStr) @@ -278,7 +326,9 @@ private[parquet] object FileSystemHelper { fs.listStatus(path).map(_.getPath) } - // finds the maximum taskid in the output file names at the given path + /** + * Finds the maximum taskid in the output file names at the given path. + */ def findMaxTaskId(pathStr: String, conf: Configuration): Int = { val files = FileSystemHelper.listFiles(pathStr, conf) // filename pattern is part-r-<int>.parquet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index f37976f7313c1f0064e62641da2031e9a31a7ed3..46c717298564204af9e445dd795d7b3f93f15f0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -19,15 +19,34 @@ package org.apache.spark.sql.parquet import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.Job +import parquet.example.data.{GroupWriter, Group} +import parquet.example.data.simple.SimpleGroup import parquet.hadoop.ParquetWriter -import parquet.hadoop.util.ContextUtil +import parquet.hadoop.api.WriteSupport +import parquet.hadoop.api.WriteSupport.WriteContext +import parquet.io.api.RecordConsumer import parquet.schema.{MessageType, MessageTypeParser} -import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.util.Utils +// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport +// with an empty configuration (it is after all not intended to be used in this way?) +// and members are private so we need to make our own in order to pass the schema +// to the writer. +private class TestGroupWriteSupport(schema: MessageType) extends WriteSupport[Group] { + var groupWriter: GroupWriter = null + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + groupWriter = new GroupWriter(recordConsumer, schema) + } + override def init(configuration: Configuration): WriteContext = { + new WriteContext(schema, new java.util.HashMap[String, String]()) + } + override def write(record: Group) { + groupWriter.write(record) + } +} + private[sql] object ParquetTestData { val testSchema = @@ -43,7 +62,7 @@ private[sql] object ParquetTestData { // field names for test assertion error messages val testSchemaFieldNames = Seq( "myboolean:Boolean", - "mtint:Int", + "myint:Int", "mystring:String", "mylong:Long", "myfloat:Float", @@ -58,6 +77,18 @@ private[sql] object ParquetTestData { |} """.stripMargin + val testFilterSchema = + """ + |message myrecord { + |required boolean myboolean; + |required int32 myint; + |required binary mystring; + |required int64 mylong; + |required float myfloat; + |required double mydouble; + |} + """.stripMargin + // field names for test assertion error messages val subTestSchemaFieldNames = Seq( "myboolean:Boolean", @@ -65,36 +96,57 @@ private[sql] object ParquetTestData { ) val testDir = Utils.createTempDir() + val testFilterDir = Utils.createTempDir() lazy val testData = new ParquetRelation(testDir.toURI.toString) def writeFile() = { testDir.delete val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet")) - val job = new Job() - val configuration: Configuration = ContextUtil.getConfiguration(job) val schema: MessageType = MessageTypeParser.parseMessageType(testSchema) + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) - val writeSupport = new RowWriteSupport() - writeSupport.setSchema(schema, configuration) - val writer = new ParquetWriter(path, writeSupport) for(i <- 0 until 15) { - val data = new Array[Any](6) + val record = new SimpleGroup(schema) if (i % 3 == 0) { - data.update(0, true) + record.add(0, true) } else { - data.update(0, false) + record.add(0, false) } if (i % 5 == 0) { - data.update(1, 5) + record.add(1, 5) + } + record.add(2, "abc") + record.add(3, i.toLong << 33) + record.add(4, 2.5F) + record.add(5, 4.5D) + writer.write(record) + } + writer.close() + } + + def writeFilterFile(records: Int = 200) = { + // for microbenchmark use: records = 300000000 + testFilterDir.delete + val path: Path = new Path(new Path(testFilterDir.toURI), new Path("part-r-0.parquet")) + val schema: MessageType = MessageTypeParser.parseMessageType(testFilterSchema) + val writeSupport = new TestGroupWriteSupport(schema) + val writer = new ParquetWriter[Group](path, writeSupport) + + for(i <- 0 to records) { + val record = new SimpleGroup(schema) + if (i % 4 == 0) { + record.add(0, true) } else { - data.update(1, null) // optional + record.add(0, false) } - data.update(2, "abc") - data.update(3, i.toLong << 33) - data.update(4, 2.5F) - data.update(5, 4.5D) - writer.write(new GenericRow(data.toArray)) + record.add(1, i) + record.add(2, i.toString) + record.add(3, i.toLong) + record.add(4, i.toFloat + 0.5f) + record.add(5, i.toDouble + 0.5d) + writer.write(record) } writer.close() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index ff1677eb8a48038d170b92a658cffc1ce6d515d4..65f4c17aeee3a44390ccc1a5fc8202ef5e9d5887 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,25 +17,25 @@ package org.apache.spark.sql.parquet -import java.io.File - import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.hadoop.fs.{Path, FileSystem} import org.apache.hadoop.mapreduce.Job import parquet.hadoop.ParquetFileWriter -import parquet.schema.MessageTypeParser import parquet.hadoop.util.ContextUtil +import parquet.schema.MessageTypeParser import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util.getTempFilePath -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.TestData +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.expressions.Equals +import org.apache.spark.sql.catalyst.types.IntegerType import org.apache.spark.util.Utils -import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, DataType} -import org.apache.spark.sql.{parquet, SchemaRDD} // Implicits import org.apache.spark.sql.test.TestSQLContext._ @@ -64,12 +64,16 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll { override def beforeAll() { ParquetTestData.writeFile() + ParquetTestData.writeFilterFile() testRDD = parquetFile(ParquetTestData.testDir.toString) testRDD.registerAsTable("testsource") + parquetFile(ParquetTestData.testFilterDir.toString) + .registerAsTable("testfiltersource") } override def afterAll() { Utils.deleteRecursively(ParquetTestData.testDir) + Utils.deleteRecursively(ParquetTestData.testFilterDir) // here we should also unregister the table?? } @@ -120,7 +124,7 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll { val scanner = new ParquetTableScan( ParquetTestData.testData.output, ParquetTestData.testData, - None)(TestSQLContext.sparkContext) + Seq())(TestSQLContext.sparkContext) val projected = scanner.pruneColumns(ParquetTypesConverter .convertToAttributes(MessageTypeParser .parseMessageType(ParquetTestData.subTestSchema))) @@ -196,7 +200,6 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll { assert(true) } - test("insert (appending) to same table via Scala API") { sql("INSERT INTO testsource SELECT * FROM testsource").collect() val double_rdd = sql("SELECT * FROM testsource").collect() @@ -239,5 +242,121 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll { Utils.deleteRecursively(file) assert(true) } + + test("create RecordFilter for simple predicates") { + val attribute1 = new AttributeReference("first", IntegerType, false)() + val predicate1 = new Equals(attribute1, new Literal(1, IntegerType)) + val filter1 = ParquetFilters.createFilter(predicate1) + assert(filter1.isDefined) + assert(filter1.get.predicate == predicate1, "predicates do not match") + assert(filter1.get.isInstanceOf[ComparisonFilter]) + val cmpFilter1 = filter1.get.asInstanceOf[ComparisonFilter] + assert(cmpFilter1.columnName == "first", "column name incorrect") + + val predicate2 = new LessThan(attribute1, new Literal(4, IntegerType)) + val filter2 = ParquetFilters.createFilter(predicate2) + assert(filter2.isDefined) + assert(filter2.get.predicate == predicate2, "predicates do not match") + assert(filter2.get.isInstanceOf[ComparisonFilter]) + val cmpFilter2 = filter2.get.asInstanceOf[ComparisonFilter] + assert(cmpFilter2.columnName == "first", "column name incorrect") + + val predicate3 = new And(predicate1, predicate2) + val filter3 = ParquetFilters.createFilter(predicate3) + assert(filter3.isDefined) + assert(filter3.get.predicate == predicate3, "predicates do not match") + assert(filter3.get.isInstanceOf[AndFilter]) + + val predicate4 = new Or(predicate1, predicate2) + val filter4 = ParquetFilters.createFilter(predicate4) + assert(filter4.isDefined) + assert(filter4.get.predicate == predicate4, "predicates do not match") + assert(filter4.get.isInstanceOf[OrFilter]) + + val attribute2 = new AttributeReference("second", IntegerType, false)() + val predicate5 = new GreaterThan(attribute1, attribute2) + val badfilter = ParquetFilters.createFilter(predicate5) + assert(badfilter.isDefined === false) + } + + test("test filter by predicate pushdown") { + for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) { + println(s"testing field $myval") + val query1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") + assert( + query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result1 = query1.collect() + assert(result1.size === 50) + assert(result1(0)(1) === 100) + assert(result1(49)(1) === 149) + val query2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") + assert( + query2.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result2 = query2.collect() + assert(result2.size === 50) + if (myval == "myint" || myval == "mylong") { + assert(result2(0)(1) === 151) + assert(result2(49)(1) === 200) + } else { + assert(result2(0)(1) === 150) + assert(result2(49)(1) === 199) + } + } + for(myval <- Seq("myint", "mylong")) { + val query3 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190 OR $myval < 10") + assert( + query3.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result3 = query3.collect() + assert(result3.size === 20) + assert(result3(0)(1) === 0) + assert(result3(9)(1) === 9) + assert(result3(10)(1) === 191) + assert(result3(19)(1) === 200) + } + for(myval <- Seq("mydouble", "myfloat")) { + val result4 = + if (myval == "mydouble") { + val query4 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10.0") + assert( + query4.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + query4.collect() + } else { + // CASTs are problematic. Here myfloat will be casted to a double and it seems there is + // currently no way to specify float constants in SqlParser? + sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10").collect() + } + assert(result4.size === 20) + assert(result4(0)(1) === 0) + assert(result4(9)(1) === 9) + assert(result4(10)(1) === 191) + assert(result4(19)(1) === 200) + } + val query5 = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40") + assert( + query5.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val booleanResult = query5.collect() + assert(booleanResult.size === 10) + for(i <- 0 until 10) { + if (!booleanResult(i).getBoolean(0)) { + fail(s"Boolean value in result row $i not true") + } + if (booleanResult(i).getInt(1) != i * 4) { + fail(s"Int value in result row $i should be ${4*i}") + } + } + val query6 = sql("SELECT * FROM testfiltersource WHERE mystring = \"100\"") + assert( + query6.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val stringResult = query6.collect() + assert(stringResult.size === 1) + assert(stringResult(0).getString(2) == "100", "stringvalue incorrect") + assert(stringResult(0).getInt(1) === 100) + } }