From 142f6d14928c780cc9e8d6d7749c5d7c08a30972 Mon Sep 17 00:00:00 2001
From: Kunal Khamar <kkhamar@outlook.com>
Date: Wed, 29 Mar 2017 12:35:19 -0700
Subject: [PATCH] [SPARK-20048][SQL] Cloning SessionState does not clone query
 execution listeners

## What changes were proposed in this pull request?

Bugfix from [SPARK-19540.](https://github.com/apache/spark/pull/16826)
Cloning SessionState does not clone query execution listeners, so cloned session is unable to listen to events on queries.

## How was this patch tested?

- Unit test

Author: Kunal Khamar <kkhamar@outlook.com>

Closes #17379 from kunalkhamar/clone-bugfix.
---
 .../org/apache/spark/sql/SparkSession.scala   | 22 ++++----
 ...rs.scala => BaseSessionStateBuilder.scala} | 24 ++++++++-
 .../spark/sql/internal/SessionState.scala     | 38 ++++---------
 .../sql/util/QueryExecutionListener.scala     | 10 ++++
 .../apache/spark/sql/SessionStateSuite.scala  | 53 +++++++++++++++++++
 .../hive/thriftserver/SparkSQLCLIDriver.scala |  2 +-
 ...te.scala => HiveSessionStateBuilder.scala} | 14 +----
 .../apache/spark/sql/hive/test/TestHive.scala |  2 +-
 .../sql/hive/HiveSessionStateSuite.scala      |  2 +-
 9 files changed, 111 insertions(+), 56 deletions(-)
 rename sql/core/src/main/scala/org/apache/spark/sql/internal/{sessionStateBuilders.scala => BaseSessionStateBuilder.scala} (92%)
 rename sql/hive/src/main/scala/org/apache/spark/sql/hive/{HiveSessionState.scala => HiveSessionStateBuilder.scala} (92%)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 49562578b2..a97297892b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.ui.SQLListener
-import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState}
+import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState}
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.streaming._
@@ -194,7 +194,7 @@ class SparkSession private(
    *
    * @since 2.0.0
    */
-  def udf: UDFRegistration = sessionState.udf
+  def udf: UDFRegistration = sessionState.udfRegistration
 
   /**
    * :: Experimental ::
@@ -990,28 +990,28 @@ object SparkSession {
   /** Reference to the root SparkSession. */
   private val defaultSession = new AtomicReference[SparkSession]
 
-  private val HIVE_SESSION_STATE_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionState"
+  private val HIVE_SESSION_STATE_BUILDER_CLASS_NAME =
+    "org.apache.spark.sql.hive.HiveSessionStateBuilder"
 
   private def sessionStateClassName(conf: SparkConf): String = {
     conf.get(CATALOG_IMPLEMENTATION) match {
-      case "hive" => HIVE_SESSION_STATE_CLASS_NAME
-      case "in-memory" => classOf[SessionState].getCanonicalName
+      case "hive" => HIVE_SESSION_STATE_BUILDER_CLASS_NAME
+      case "in-memory" => classOf[SessionStateBuilder].getCanonicalName
     }
   }
 
   /**
    * Helper method to create an instance of `SessionState` based on `className` from conf.
-   * The result is either `SessionState` or `HiveSessionState`.
+   * The result is either `SessionState` or a Hive based `SessionState`.
    */
   private def instantiateSessionState(
       className: String,
       sparkSession: SparkSession): SessionState = {
-
     try {
-      // get `SessionState.apply(SparkSession)`
+      // invoke `new [Hive]SessionStateBuilder(SparkSession, Option[SessionState])`
       val clazz = Utils.classForName(className)
-      val method = clazz.getMethod("apply", sparkSession.getClass)
-      method.invoke(null, sparkSession).asInstanceOf[SessionState]
+      val ctor = clazz.getConstructors.head
+      ctor.newInstance(sparkSession, None).asInstanceOf[BaseSessionStateBuilder].build()
     } catch {
       case NonFatal(e) =>
         throw new IllegalArgumentException(s"Error while instantiating '$className':", e)
@@ -1023,7 +1023,7 @@ object SparkSession {
    */
   private[spark] def hiveClassesArePresent: Boolean = {
     try {
-      Utils.classForName(HIVE_SESSION_STATE_CLASS_NAME)
+      Utils.classForName(HIVE_SESSION_STATE_BUILDER_CLASS_NAME)
       Utils.classForName("org.apache.hadoop.hive.conf.HiveConf")
       true
     } catch {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
similarity index 92%
rename from sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index b8f645fdee..2b14eca919 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/sessionStateBuilders.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.internal
 
 import org.apache.spark.SparkConf
 import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy}
+import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration}
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.streaming.StreamingQueryManager
+import org.apache.spark.sql.util.ExecutionListenerManager
 
 /**
  * Builder class that coordinates construction of a new [[SessionState]].
@@ -133,6 +134,14 @@ abstract class BaseSessionStateBuilder(
     catalog
   }
 
+  /**
+   * Interface exposed to the user for registering user-defined functions.
+   *
+   * Note 1: The user-defined functions must be deterministic.
+   * Note 2: This depends on the `functionRegistry` field.
+   */
+  protected def udfRegistration: UDFRegistration = new UDFRegistration(functionRegistry)
+
   /**
    * Logical query plan analyzer for resolving unresolved attributes and relations.
    *
@@ -232,6 +241,16 @@ abstract class BaseSessionStateBuilder(
    */
   protected def streamingQueryManager: StreamingQueryManager = new StreamingQueryManager(session)
 
+  /**
+   * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
+   * that listen for execution metrics.
+   *
+   * This gets cloned from parent if available, otherwise is a new instance is created.
+   */
+  protected def listenerManager: ExecutionListenerManager = {
+    parentState.map(_.listenerManager.clone()).getOrElse(new ExecutionListenerManager)
+  }
+
   /**
    * Function used to make clones of the session state.
    */
@@ -245,17 +264,18 @@ abstract class BaseSessionStateBuilder(
    */
   def build(): SessionState = {
     new SessionState(
-      session.sparkContext,
       session.sharedState,
       conf,
       experimentalMethods,
       functionRegistry,
+      udfRegistration,
       catalog,
       sqlParser,
       analyzer,
       optimizer,
       planner,
       streamingQueryManager,
+      listenerManager,
       resourceLoader,
       createQueryExecution,
       createClone)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index c6241d923d..1b341a12fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -32,43 +32,46 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.streaming.StreamingQueryManager
-import org.apache.spark.sql.util.ExecutionListenerManager
+import org.apache.spark.sql.util.{ExecutionListenerManager, QueryExecutionListener}
 
 /**
  * A class that holds all session-specific state in a given [[SparkSession]].
  *
- * @param sparkContext The [[SparkContext]].
- * @param sharedState The shared state.
+ * @param sharedState The state shared across sessions, e.g. global view manager, external catalog.
  * @param conf SQL-specific key-value configurations.
- * @param experimentalMethods The experimental methods.
+ * @param experimentalMethods Interface to add custom planning strategies and optimizers.
  * @param functionRegistry Internal catalog for managing functions registered by the user.
+ * @param udfRegistration Interface exposed to the user for registering user-defined functions.
  * @param catalog Internal catalog for managing table and database states.
  * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
  * @param analyzer Logical query plan analyzer for resolving unresolved attributes and relations.
  * @param optimizer Logical query plan optimizer.
- * @param planner Planner that converts optimized logical plans to physical plans
+ * @param planner Planner that converts optimized logical plans to physical plans.
  * @param streamingQueryManager Interface to start and stop streaming queries.
+ * @param listenerManager Interface to register custom [[QueryExecutionListener]]s.
+ * @param resourceLoader Session shared resource loader to load JARs, files, etc.
  * @param createQueryExecution Function used to create QueryExecution objects.
  * @param createClone Function used to create clones of the session state.
  */
 private[sql] class SessionState(
-    sparkContext: SparkContext,
     sharedState: SharedState,
     val conf: SQLConf,
     val experimentalMethods: ExperimentalMethods,
     val functionRegistry: FunctionRegistry,
+    val udfRegistration: UDFRegistration,
     val catalog: SessionCatalog,
     val sqlParser: ParserInterface,
     val analyzer: Analyzer,
     val optimizer: Optimizer,
     val planner: SparkPlanner,
     val streamingQueryManager: StreamingQueryManager,
+    val listenerManager: ExecutionListenerManager,
     val resourceLoader: SessionResourceLoader,
     createQueryExecution: LogicalPlan => QueryExecution,
     createClone: (SparkSession, SessionState) => SessionState) {
 
   def newHadoopConf(): Configuration = SessionState.newHadoopConf(
-    sparkContext.hadoopConfiguration,
+    sharedState.sparkContext.hadoopConfiguration,
     conf)
 
   def newHadoopConfWithOptions(options: Map[String, String]): Configuration = {
@@ -81,18 +84,6 @@ private[sql] class SessionState(
     hadoopConf
   }
 
-  /**
-   * Interface exposed to the user for registering user-defined functions.
-   * Note that the user-defined functions must be deterministic.
-   */
-  val udf: UDFRegistration = new UDFRegistration(functionRegistry)
-
-  /**
-   * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
-   * that listen for execution metrics.
-   */
-  val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
-
   /**
    * Get an identical copy of the `SessionState` and associate it with the given `SparkSession`
    */
@@ -110,13 +101,6 @@ private[sql] class SessionState(
 }
 
 private[sql] object SessionState {
-  /**
-   * Create a new [[SessionState]] for the given session.
-   */
-  def apply(session: SparkSession): SessionState = {
-    new SessionStateBuilder(session).build()
-  }
-
   def newHadoopConf(hadoopConf: Configuration, sqlConf: SQLConf): Configuration = {
     val newHadoopConf = new Configuration(hadoopConf)
     sqlConf.getAllConfs.foreach { case (k, v) => if (v ne null) newHadoopConf.set(k, v) }
@@ -155,7 +139,7 @@ class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoade
   /**
    * Add a jar path to [[SparkContext]] and the classloader.
    *
-   * Note: this method seems not access any session state, but the subclass `HiveSessionState` needs
+   * Note: this method seems not access any session state, but a Hive based `SessionState` needs
    * to add the jar to its hive client for the current session. Hence, it still needs to be in
    * [[SessionState]].
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
index 26ad0eadd9..f6240d85fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -98,6 +98,16 @@ class ExecutionListenerManager private[sql] () extends Logging {
     listeners.clear()
   }
 
+  /**
+   * Get an identical copy of this listener manager.
+   */
+  @DeveloperApi
+  override def clone(): ExecutionListenerManager = writeLock {
+    val newListenerManager = new ExecutionListenerManager
+    listeners.foreach(newListenerManager.register)
+    newListenerManager
+  }
+
   private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
     readLock {
       withErrorHandling { listener =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
index 2d5e37242a..5638c8eeda 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala
@@ -19,10 +19,13 @@ package org.apache.spark.sql
 
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.BeforeAndAfterEach
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.util.QueryExecutionListener
 
 class SessionStateSuite extends SparkFunSuite
     with BeforeAndAfterEach with BeforeAndAfterAll {
@@ -122,6 +125,56 @@ class SessionStateSuite extends SparkFunSuite
     }
   }
 
+  test("fork new session and inherit listener manager") {
+    class CommandCollector extends QueryExecutionListener {
+      val commands: ArrayBuffer[String] = ArrayBuffer.empty[String]
+      override def onFailure(funcName: String, qe: QueryExecution, ex: Exception) : Unit = {}
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        commands += funcName
+      }
+    }
+    val collectorA = new CommandCollector
+    val collectorB = new CommandCollector
+    val collectorC = new CommandCollector
+
+    try {
+      def runCollectQueryOn(sparkSession: SparkSession): Unit = {
+        val tupleEncoder = Encoders.tuple(Encoders.scalaInt, Encoders.STRING)
+        val df = sparkSession.createDataset(Seq(1 -> "a"))(tupleEncoder).toDF("i", "j")
+        df.select("i").collect()
+      }
+
+      activeSession.listenerManager.register(collectorA)
+      val forkedSession = activeSession.cloneSession()
+
+      // inheritance
+      assert(forkedSession ne activeSession)
+      assert(forkedSession.listenerManager ne activeSession.listenerManager)
+      runCollectQueryOn(forkedSession)
+      assert(collectorA.commands.length == 1) // forked should callback to A
+      assert(collectorA.commands(0) == "collect")
+
+      // independence
+      // => changes to forked do not affect original
+      forkedSession.listenerManager.register(collectorB)
+      runCollectQueryOn(activeSession)
+      assert(collectorB.commands.isEmpty) // original should not callback to B
+      assert(collectorA.commands.length == 2) // original should still callback to A
+      assert(collectorA.commands(1) == "collect")
+      // <= changes to original do not affect forked
+      activeSession.listenerManager.register(collectorC)
+      runCollectQueryOn(forkedSession)
+      assert(collectorC.commands.isEmpty) // forked should not callback to C
+      assert(collectorA.commands.length == 3) // forked should still callback to A
+      assert(collectorB.commands.length == 1) // forked should still callback to B
+      assert(collectorA.commands(2) == "collect")
+      assert(collectorB.commands(0) == "collect")
+    } finally {
+      activeSession.listenerManager.unregister(collectorA)
+      activeSession.listenerManager.unregister(collectorC)
+    }
+  }
+
   test("fork new sessions and run query on inherited table") {
     def checkTableExists(sparkSession: SparkSession): Unit = {
       QueryTest.checkAnswer(sparkSession.sql(
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
index 0c79b6f421..390b9b6d68 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
@@ -38,7 +38,7 @@ import org.apache.thrift.transport.TSocket
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.hive.{HiveSessionState, HiveUtils}
+import org.apache.spark.sql.hive.HiveUtils
 import org.apache.spark.util.ShutdownHookManager
 
 /**
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
similarity index 92%
rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index f49e6bb418..8048c2ba2c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -28,19 +28,7 @@ import org.apache.spark.sql.hive.client.HiveClient
 import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState}
 
 /**
- * Entry object for creating a Hive aware [[SessionState]].
- */
-private[hive] object HiveSessionState {
-  /**
-   * Create a new Hive aware [[SessionState]]. for the given session.
-   */
-  def apply(session: SparkSession): SessionState = {
-    new HiveSessionStateBuilder(session).build()
-  }
-}
-
-/**
- * Builder that produces a [[HiveSessionState]].
+ * Builder that produces a Hive aware [[SessionState]].
  */
 @Experimental
 @InterfaceStability.Unstable
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 0bcf219922..d9bb1f8c7e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.command.CacheTableCommand
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.client.HiveClient
-import org.apache.spark.sql.internal._
+import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf, WithTestConf}
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.util.{ShutdownHookManager, Utils}
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala
index 67c77fb62f..958ad3e1c3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 
 /**
- * Run all tests from `SessionStateSuite` with a `HiveSessionState`.
+ * Run all tests from `SessionStateSuite` with a Hive based `SessionState`.
  */
 class HiveSessionStateSuite extends SessionStateSuite
   with TestHiveSingleton with BeforeAndAfterEach {
-- 
GitLab