Skip to content
Snippets Groups Projects
Commit 9ab725ea authored by Tathagata Das's avatar Tathagata Das
Browse files

[SPARK-18758][SS] StreamingQueryListener events from a StreamingQuery should...

[SPARK-18758][SS] StreamingQueryListener events from a StreamingQuery should be sent only to the listeners in the same session as the query

## What changes were proposed in this pull request?

Listeners added with `sparkSession.streams.addListener(l)` are added to a SparkSession. So events only from queries in the same session as a listener should be posted to the listener. Currently, all the events gets rerouted through the Spark's main listener bus, that is,
- StreamingQuery posts event to StreamingQueryListenerBus. Only the queries associated with the same session as the bus posts events to it.
- StreamingQueryListenerBus posts event to Spark's main LiveListenerBus as a SparkEvent.
- StreamingQueryListenerBus also subscribes to LiveListenerBus events thus getting back the posted event in a different thread.
- The received is posted to the registered listeners.

The problem is that *all StreamingQueryListenerBuses in all sessions* gets the events and posts them to their listeners. This is wrong.

In this PR, I solve it by making StreamingQueryListenerBus track active queries (by their runIds) when a query posts the QueryStarted event to the bus. This allows the rerouted events to be filtered using the tracked queries.

Note that this list needs to be maintained separately
from the `StreamingQueryManager.activeQueries` because a terminated query is cleared from
`StreamingQueryManager.activeQueries` as soon as it is stopped, but the this ListenerBus must
clear a query only after the termination event of that query has been posted lazily, much after the query has been terminated.

Credit goes to zsxwing for coming up with the initial idea.

## How was this patch tested?
Updated test harness code to use the correct session, and added new unit test.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #16186 from tdas/SPARK-18758.
parent aad11209
No related branches found
No related tags found
No related merge requests found
......@@ -17,6 +17,10 @@
package org.apache.spark.sql.execution.streaming
import java.util.UUID
import scala.collection.mutable
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent}
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.ListenerBus
......@@ -25,7 +29,11 @@ import org.apache.spark.util.ListenerBus
* A bus to forward events to [[StreamingQueryListener]]s. This one will send received
* [[StreamingQueryListener.Event]]s to the Spark listener bus. It also registers itself with
* Spark listener bus, so that it can receive [[StreamingQueryListener.Event]]s and dispatch them
* to StreamingQueryListener.
* to StreamingQueryListeners.
*
* Note that each bus and its registered listeners are associated with a single SparkSession
* and StreamingQueryManager. So this bus will dispatch events to registered listeners for only
* those queries that were started in the associated SparkSession.
*/
class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
extends SparkListener with ListenerBus[StreamingQueryListener, StreamingQueryListener.Event] {
......@@ -35,12 +43,30 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
sparkListenerBus.addListener(this)
/**
* Post a StreamingQueryListener event to the Spark listener bus asynchronously. This event will
* be dispatched to all StreamingQueryListener in the thread of the Spark listener bus.
* RunIds of active queries whose events are supposed to be forwarded by this ListenerBus
* to registered `StreamingQueryListeners`.
*
* Note 1: We need to track runIds instead of ids because the runId is unique for every started
* query, even it its a restart. So even if a query is restarted, this bus will identify them
* separately and correctly account for the restart.
*
* Note 2: This list needs to be maintained separately from the
* `StreamingQueryManager.activeQueries` because a terminated query is cleared from
* `StreamingQueryManager.activeQueries` as soon as it is stopped, but the this ListenerBus
* must clear a query only after the termination event of that query has been posted.
*/
private val activeQueryRunIds = new mutable.HashSet[UUID]
/**
* Post a StreamingQueryListener event to the added StreamingQueryListeners.
* Note that only the QueryStarted event is posted to the listener synchronously. Other events
* are dispatched to Spark listener bus. This method is guaranteed to be called by queries in
* the same SparkSession as this listener.
*/
def post(event: StreamingQueryListener.Event) {
event match {
case s: QueryStartedEvent =>
activeQueryRunIds.synchronized { activeQueryRunIds += s.runId }
sparkListenerBus.post(s)
// post to local listeners to trigger callbacks
postToAll(s)
......@@ -63,18 +89,32 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
}
}
/**
* Dispatch events to registered StreamingQueryListeners. Only the events associated queries
* started in the same SparkSession as this ListenerBus will be dispatched to the listeners.
*/
override protected def doPostEvent(
listener: StreamingQueryListener,
event: StreamingQueryListener.Event): Unit = {
def shouldReport(runId: UUID): Boolean = {
activeQueryRunIds.synchronized { activeQueryRunIds.contains(runId) }
}
event match {
case queryStarted: QueryStartedEvent =>
listener.onQueryStarted(queryStarted)
if (shouldReport(queryStarted.runId)) {
listener.onQueryStarted(queryStarted)
}
case queryProgress: QueryProgressEvent =>
listener.onQueryProgress(queryProgress)
if (shouldReport(queryProgress.progress.runId)) {
listener.onQueryProgress(queryProgress)
}
case queryTerminated: QueryTerminatedEvent =>
listener.onQueryTerminated(queryTerminated)
if (shouldReport(queryTerminated.runId)) {
listener.onQueryTerminated(queryTerminated)
activeQueryRunIds.synchronized { activeQueryRunIds -= queryTerminated.runId }
}
case _ =>
}
}
}
......@@ -70,11 +70,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
def schema: StructType = encoder.schema
def toDS()(implicit sqlContext: SQLContext): Dataset[A] = {
def toDS(): Dataset[A] = {
Dataset(sqlContext.sparkSession, logicalPlan)
}
def toDF()(implicit sqlContext: SQLContext): DataFrame = {
def toDF(): DataFrame = {
Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
}
......
......@@ -231,8 +231,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = {
val stream = _stream.toDF()
val sparkSession = stream.sparkSession // use the session in DF, not the default session
var pos = 0
var currentPlan: LogicalPlan = stream.logicalPlan
var currentStream: StreamExecution = null
var lastStream: StreamExecution = null
val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for
......@@ -319,7 +319,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
""".stripMargin)
}
val testThread = Thread.currentThread()
val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
var manualClockExpectedTime = -1L
try {
......@@ -337,14 +336,16 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
additionalConfs.foreach(pair => {
val value =
if (spark.conf.contains(pair._1)) Some(spark.conf.get(pair._1)) else None
if (sparkSession.conf.contains(pair._1)) {
Some(sparkSession.conf.get(pair._1))
} else None
resetConfValues(pair._1) = value
spark.conf.set(pair._1, pair._2)
sparkSession.conf.set(pair._1, pair._2)
})
lastStream = currentStream
currentStream =
spark
sparkSession
.streams
.startQuery(
None,
......@@ -518,8 +519,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
// Rollback prev configuration values
resetConfValues.foreach {
case (key, Some(value)) => spark.conf.set(key, value)
case (key, None) => spark.conf.unset(key)
case (key, Some(value)) => sparkSession.conf.set(key, value)
case (key, None) => sparkSession.conf.unset(key)
}
}
}
......
......@@ -20,6 +20,7 @@ package org.apache.spark.sql.streaming
import java.util.UUID
import scala.collection.mutable
import scala.concurrent.duration._
import org.scalactic.TolerantNumerics
import org.scalatest.concurrent.AsyncAssertions.Waiter
......@@ -30,6 +31,7 @@ import org.scalatest.PrivateMethodTester._
import org.apache.spark.SparkException
import org.apache.spark.scheduler._
import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.StreamingQueryListener._
......@@ -45,7 +47,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
after {
spark.streams.active.foreach(_.stop())
assert(spark.streams.active.isEmpty)
assert(addedListeners.isEmpty)
assert(addedListeners().isEmpty)
// Make sure we don't leak any events to the next test
spark.sparkContext.listenerBus.waitUntilEmpty(10000)
}
......@@ -148,7 +150,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
assert(isListenerActive(listener1) === false)
assert(isListenerActive(listener2) === true)
} finally {
addedListeners.foreach(spark.streams.removeListener)
addedListeners().foreach(spark.streams.removeListener)
}
}
......@@ -251,6 +253,57 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
}
}
test("listener only posts events from queries started in the related sessions") {
val session1 = spark.newSession()
val session2 = spark.newSession()
val collector1 = new EventCollector
val collector2 = new EventCollector
def runQuery(session: SparkSession): Unit = {
collector1.reset()
collector2.reset()
val mem = MemoryStream[Int](implicitly[Encoder[Int]], session.sqlContext)
testStream(mem.toDS)(
AddData(mem, 1, 2, 3),
CheckAnswer(1, 2, 3)
)
session.sparkContext.listenerBus.waitUntilEmpty(5000)
}
def assertEventsCollected(collector: EventCollector): Unit = {
assert(collector.startEvent !== null)
assert(collector.progressEvents.nonEmpty)
assert(collector.terminationEvent !== null)
}
def assertEventsNotCollected(collector: EventCollector): Unit = {
assert(collector.startEvent === null)
assert(collector.progressEvents.isEmpty)
assert(collector.terminationEvent === null)
}
assert(session1.ne(session2))
assert(session1.streams.ne(session2.streams))
withListenerAdded(collector1, session1) {
assert(addedListeners(session1).nonEmpty)
withListenerAdded(collector2, session2) {
assert(addedListeners(session2).nonEmpty)
// query on session1 should send events only to collector1
runQuery(session1)
assertEventsCollected(collector1)
assertEventsNotCollected(collector2)
// query on session2 should send events only to collector2
runQuery(session2)
assertEventsCollected(collector2)
assertEventsNotCollected(collector1)
}
}
}
testQuietly("ReplayListenerBus should ignore broken event jsons generated in 2.0.0") {
// query-event-logs-version-2.0.0.txt has all types of events generated by
// Structured Streaming in Spark 2.0.0.
......@@ -298,21 +351,23 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
}
}
private def withListenerAdded(listener: StreamingQueryListener)(body: => Unit): Unit = {
private def withListenerAdded(
listener: StreamingQueryListener,
session: SparkSession = spark)(body: => Unit): Unit = {
try {
failAfter(streamingTimeout) {
spark.streams.addListener(listener)
session.streams.addListener(listener)
body
}
} finally {
spark.streams.removeListener(listener)
session.streams.removeListener(listener)
}
}
private def addedListeners(): Array[StreamingQueryListener] = {
private def addedListeners(session: SparkSession = spark): Array[StreamingQueryListener] = {
val listenerBusMethod =
PrivateMethod[StreamingQueryListenerBus]('listenerBus)
val listenerBus = spark.streams invokePrivate listenerBusMethod()
val listenerBus = session.streams invokePrivate listenerBusMethod()
listenerBus.listeners.toArray.map(_.asInstanceOf[StreamingQueryListener])
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment