Commit 1ff0580e authored by Reynold Xin

[SPARK-10093] [SPARK-10096] [SQL] Avoid transformation on executors & fix UDFs on complex types

This is kind of a weird case, but given a sufficiently complex query plan (in this case a TungstenProject with an Exchange underneath), we could hit NPEs on the executors because transformAllExpressions was being called inside mapPartitions, i.e. on the executors rather than on the driver.

In general, we should ensure that all transformations occur on the driver and not on the executors. Some reasons for avoiding executor-side transformations include:

* (this case) Some operator constructors require state such as access to the Spark/SQL conf, so doing a makeCopy on an executor can fail.
* (an unrelated reason for avoiding executor-side transformations) ExprIds are allocated from an atomic integer, so constructing them anywhere other than the driver can violate their uniqueness constraint; see the sketch after this list.
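To make the second point concrete, here is a minimal, hypothetical sketch (not Spark's actual ExprId code) of why a per-JVM atomic counter breaks uniqueness once expressions are constructed on executors:

import java.util.concurrent.atomic.AtomicLong

// Hypothetical allocator, simplified for illustration: one counter per JVM process.
object ExprIdAllocator {
  private val curId = new AtomicLong(0)
  def newExprId: Long = curId.getAndIncrement()
}

// Driver JVM:   ExprIdAllocator.newExprId  // 0, 1, 2, ...
// Executor JVM: ExprIdAllocator.newExprId  // 0, 1, 2, ... again, because each JVM
// has its own counter; ids minted during an executor-side transformation can
// therefore collide with ids already assigned on the driver.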

This subsumes #8285.

Author: Reynold Xin <rxin@databricks.com>
Author: Michael Armbrust <michael@databricks.com>

Closes #8295 from rxin/SPARK-10096.
parent 270ee677
@@ -206,7 +206,9 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression {

   override def nullable: Boolean = false

-  override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+  override def eval(input: InternalRow): Any = {
+    InternalRow(children.map(_.eval(input)): _*)
+  }

   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val eval = GenerateUnsafeProjection.createCode(ctx, children)
@@ -244,7 +246,9 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression

   override def nullable: Boolean = false

-  override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+  override def eval(input: InternalRow): Any = {
+    InternalRow(valExprs.map(_.eval(input)): _*)
+  }

   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val eval = GenerateUnsafeProjection.createCode(ctx, valExprs)
...
@@ -75,14 +75,16 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan)

   override def output: Seq[Attribute] = projectList.map(_.toAttribute)

+  /** Rewrite the project list to use unsafe expressions as needed. */
+  protected val unsafeProjectList = projectList.map(_ transform {
+    case CreateStruct(children) => CreateStructUnsafe(children)
+    case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
+  })
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numRows = longMetric("numRows")
     child.execute().mapPartitions { iter =>
-      this.transformAllExpressions {
-        case CreateStruct(children) => CreateStructUnsafe(children)
-        case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
-      }
-      val project = UnsafeProjection.create(projectList, child.output)
+      val project = UnsafeProjection.create(unsafeProjectList, child.output)
       iter.map { row =>
         numRows += 1
         project(row)
...
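The shape of the fix above is the general pattern: the expression rewrite is now a val evaluated once on the driver when the operator is constructed, and only the per-partition UnsafeProjection.create call still runs inside mapPartitions on the executors. Below is a minimal sketch of that pattern on a plain RDD, using hypothetical rewrite/compile helpers rather than Spark APIs:

import org.apache.spark.rdd.RDD

object DriverSideRewriteSketch {
  // Hypothetical helpers, for illustration only (not Spark APIs).
  def rewrite(exprs: Seq[String]): Seq[String] = exprs.map(_.trim)
  def compile(exprs: Seq[String]): String => String = s => exprs.mkString(",") + ":" + s

  def project(rdd: RDD[String], exprs: Seq[String]): RDD[String] = {
    // Evaluated once, on the driver, while the operator is being built
    // (analogous to the new unsafeProjectList val in TungstenProject).
    val rewritten = rewrite(exprs)
    rdd.mapPartitions { iter =>
      // Evaluated on an executor, once per partition; it reads the precomputed
      // rewritten list captured in the closure, but never mutates the plan,
      // which is what the old executor-side transformAllExpressions call did.
      val proj = compile(rewritten)
      iter.map(proj)
    }
  }
}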
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext

/**
 * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map).
 */
class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("UDF on struct") {
    val f = udf((a: String) => a)
    val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.select(struct($"a").as("s")).select(f($"s.a")).collect()
  }

  test("UDF on named_struct") {
    val f = udf((a: String) => a)
    val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect()
  }

  test("UDF on array") {
    val f = udf((a: String) => a)
    val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect()
  }
}
@@ -878,4 +878,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val df = Seq(("x", (1, 1)), ("y", (2, 2))).toDF("a", "b")
     checkAnswer(df.groupBy("b._1").agg(sum("b._2")), Row(1, 1) :: Row(2, 2) :: Nil)
   }
+
+  test("SPARK-10093: Avoid transformations on executors") {
+    val df = Seq((1, 1)).toDF("a", "b")
+    df.where($"a" === 1)
+      .select($"a", $"b", struct($"b"))
+      .orderBy("a")
+      .select(struct($"b"))
+      .collect()
+  }
 }