From fccb337f9d1e44a83cfcc00ce33eae1fad367695 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <viirya@gmail.com>
Date: Fri, 17 Nov 2017 17:43:40 +0100
Subject: [PATCH] [SPARK-22538][ML] SQLTransformer should not unpersist
 possibly cached input dataset

## What changes were proposed in this pull request?

`SQLTransformer.transform` unpersists the input dataset when dropping the temporary view. We should not change the input dataset's cache status.

## How was this patch tested?

Added test.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #19772 from viirya/SPARK-22538.
---
 .../org/apache/spark/ml/feature/SQLTransformer.scala   |  3 ++-
 .../spark/ml/feature/SQLTransformerSuite.scala         | 12 ++++++++++++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 62c1972aab..0fb1d8c5dc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -70,7 +70,8 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") override val uid: String)
     dataset.createOrReplaceTempView(tableName)
     val realStatement = $(statement).replace(tableIdentifier, tableName)
     val result = dataset.sparkSession.sql(realStatement)
-    dataset.sparkSession.catalog.dropTempView(tableName)
+    // Call SessionCatalog.dropTempView to avoid unpersisting the possibly cached dataset.
+    dataset.sparkSession.sessionState.catalog.dropTempView(tableName)
     result
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
index 753f890c48..673a146e61 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.storage.StorageLevel
 
 class SQLTransformerSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -60,4 +61,15 @@ class SQLTransformerSuite
     val expected = StructType(Seq(StructField("id1", LongType, nullable = false)))
     assert(outputSchema === expected)
   }
+
+  test("SPARK-22538: SQLTransformer should not unpersist given dataset") {
+    val df = spark.range(10)
+    df.cache()
+    df.count()
+    assert(df.storageLevel != StorageLevel.NONE)
+    new SQLTransformer()
+      .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
+      .transform(df)
+    assert(df.storageLevel != StorageLevel.NONE)
+  }
 }
-- 
GitLab
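
For context beyond the patch itself, the sketch below illustrates the behavior the new test locks in: a `DataFrame` that was cached before being passed to `SQLTransformer.transform` should still be cached afterwards. This is a minimal sketch, not part of the patch, and assumes an already-running `SparkSession` bound to `spark` (as in `spark-shell`); the variable names are illustrative only.

```scala
import org.apache.spark.ml.feature.SQLTransformer
import org.apache.spark.storage.StorageLevel

// Cache and materialize a small input DataFrame.
val df = spark.range(10).toDF("id")
df.cache()
df.count()
assert(df.storageLevel != StorageLevel.NONE)  // cached

// Run the transformer; internally it creates and then drops a temp view.
val transformed = new SQLTransformer()
  .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
  .transform(df)
transformed.show()

// Before this patch, dropping the temp view via spark.catalog.dropTempView
// could also unpersist `df`; with SessionCatalog.dropTempView it stays cached.
assert(df.storageLevel != StorageLevel.NONE)
```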