Commit 5b8bb1b2 authored by Tathagata Das

[SPARK-9572] [STREAMING] [PYSPARK] Added StreamingContext.getActiveOrCreate() in Python

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #8080 from tdas/SPARK-9572 and squashes the following commits:

64a231d [Tathagata Das] Fix based on comments
741a0d0 [Tathagata Das] Fixed style
f4f094c [Tathagata Das] Tweaked test
9afcdbe [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9572
e21488d [Tathagata Das] Minor update
1a371d9 [Tathagata Das] Addressed comments.
60479da [Tathagata Das] Fixed indent
9c2da9c [Tathagata Das] Fixed bugs
b5bd32c [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into SPARK-9572
b55b348 [Tathagata Das] Removed prints
5781728 [Tathagata Das] Fix style issues
b711214 [Tathagata Das] Reverted run-tests.py
643b59d [Tathagata Das] Revert unnecessary change
150e58c [Tathagata Das] Added StreamingContext.getActiveOrCreate() in Python
parent dbd778d8
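Before the diff itself, the entry point this commit adds is typically driven as follows. A minimal sketch: the app name, batch interval, socket source, and checkpoint directory are illustrative assumptions, not part of the patch.

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    CHECKPOINT_DIR = "/tmp/streaming-checkpoint"  # hypothetical path

    def create_context():
        # Runs only when there is neither an active context nor checkpoint data.
        sc = SparkContext(appName="RestartableApp")  # illustrative app name
        ssc = StreamingContext(sc, 1)                # 1-second batches (illustrative)
        ssc.checkpoint(CHECKPOINT_DIR)
        ssc.socketTextStream("localhost", 9999).pprint()
        return ssc

    # Reuses a running context, else recovers from checkpoint, else creates anew.
    ssc = StreamingContext.getActiveOrCreate(CHECKPOINT_DIR, create_context)
    ssc.start()
    ssc.awaitTermination()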
@@ -86,6 +86,9 @@ class StreamingContext(object):
    """
    _transformerSerializer = None

    # Reference to a currently active StreamingContext
    _activeContext = None

    def __init__(self, sparkContext, batchDuration=None, jssc=None):
        """
        Create a new StreamingContext.
@@ -142,10 +145,10 @@ class StreamingContext(object):
        Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
        If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
        recreated from the checkpoint data. If the data does not exist, then the provided setupFunc
        will be used to create a new context.

        @param checkpointPath: Checkpoint directory used in an earlier streaming program
        @param setupFunc: Function to create a new context and setup DStreams
        """
        # TODO: support checkpoint in HDFS
        if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath):
@@ -170,6 +173,52 @@ class StreamingContext(object):
        cls._transformerSerializer.ctx = sc
        return StreamingContext(sc, None, jssc)
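As the docstring above notes, getOrCreate() prefers checkpoint data over setupFunc. A sketch of the two paths, assuming an existing SparkContext `sc` and an illustrative directory:

    def setup():
        ssc = StreamingContext(sc, 1)
        ssc.checkpoint("/tmp/cps")  # enable checkpointing so later runs can recover
        # ... define DStreams here ...
        return ssc

    # First run: /tmp/cps holds no data, so setup() builds a new context.
    # Later runs: the context and its DStream graph are recreated from the
    # checkpoint files, and setup() is never invoked.
    ssc = StreamingContext.getOrCreate("/tmp/cps", setup)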

    @classmethod
    def getActive(cls):
        """
        Return either the currently active StreamingContext (i.e., if there is a context started
        but not stopped) or None.
        """
        activePythonContext = cls._activeContext
        if activePythonContext is not None:
            # Verify that the currently running Java StreamingContext is active and is the same
            # one backing the supposedly active Python context
            activePythonContextJavaId = activePythonContext._jssc.ssc().hashCode()
            activeJvmContextOption = activePythonContext._jvm.StreamingContext.getActive()

            if activeJvmContextOption.isEmpty():
                cls._activeContext = None
            elif activeJvmContextOption.get().hashCode() != activePythonContextJavaId:
                cls._activeContext = None
                raise Exception("JVM's active JavaStreamingContext is not the JavaStreamingContext "
                                "backing the active Python StreamingContext. This is unexpected.")
        return cls._activeContext
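The registration that getActive() consults is made by start() and cleared by stop(), as shown further down in this diff. A sketch of the observable lifecycle, assuming an existing SparkContext `sc`:

    ssc = StreamingContext(sc, 1)
    ssc.queueStream([[1, 2, 3]]).pprint()

    assert StreamingContext.getActive() is None   # nothing started yet
    ssc.start()
    assert StreamingContext.getActive() is ssc    # start() registers the context
    ssc.stop(stopSparkContext=False)
    assert StreamingContext.getActive() is None   # stop() clears the registration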

    @classmethod
    def getActiveOrCreate(cls, checkpointPath, setupFunc):
        """
        Either return the active StreamingContext (i.e., currently started but not stopped),
        or recreate a StreamingContext from checkpoint data, or create a new StreamingContext
        using the provided setupFunc function. If checkpointPath is None or does not contain
        valid checkpoint data, then setupFunc will be called to create a new context and setup
        DStreams.

        @param checkpointPath: Checkpoint directory used in an earlier streaming program. Can be
                               None if the intention is to always create a new context when there
                               is no active context.
        @param setupFunc: Function to create a new context and setup DStreams
        """
        if setupFunc is None:
            raise Exception("setupFunc cannot be None")
        activeContext = cls.getActive()
        if activeContext is not None:
            return activeContext
        elif checkpointPath is not None:
            return cls.getOrCreate(checkpointPath, setupFunc)
        else:
            return setupFunc()
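Putting the branches above together, the resolution order is: active context first, then checkpoint recovery, then setupFunc. A hedged sketch (the checkpoint path and stream body are placeholders):

    def setup():
        ssc = StreamingContext(sc, 1)  # assumes an existing SparkContext `sc`
        ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
        return ssc

    # 1. A context that is started but not stopped is returned as-is.
    # 2. Otherwise a non-None checkpointPath defers to getOrCreate(), which
    #    recovers the context if the directory holds checkpoint data.
    # 3. With checkpointPath=None (or an empty directory), setup() is called.
    ssc = StreamingContext.getActiveOrCreate("/tmp/cps", setup)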

    @property
    def sparkContext(self):
        """
@@ -182,6 +231,7 @@ class StreamingContext(object):
        Start the execution of the streams.
        """
        self._jssc.start()
        StreamingContext._activeContext = self

    def awaitTermination(self, timeout=None):
        """
@@ -212,6 +262,7 @@ class StreamingContext(object):
        of all received data to be completed
        """
        self._jssc.stop(stopSparkContext, stopGraceFully)
        StreamingContext._activeContext = None
        if stopSparkContext:
            self._sc.stop()
...
@@ -24,6 +24,7 @@ import operator
import tempfile
import random
import struct
import shutil
from functools import reduce

if sys.version_info[:2] <= (2, 6):
@@ -59,12 +60,21 @@ class PySparkStreamingTestCase(unittest.TestCase):
    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()
        # Clean up in the JVM just in case there have been issues in the Python API
        jSparkContextOption = SparkContext._jvm.SparkContext.get()
        if jSparkContextOption.nonEmpty():
            jSparkContextOption.get().stop()
    def setUp(self):
        self.ssc = StreamingContext(self.sc, self.duration)

    def tearDown(self):
        if self.ssc is not None:
            self.ssc.stop(False)
        # Clean up in the JVM just in case there have been issues in the Python API
        jStreamingContextOption = StreamingContext._jvm.StreamingContext.getActive()
        if jStreamingContextOption.nonEmpty():
            jStreamingContextOption.get().stop(False)
    def wait_for(self, result, n):
        start_time = time.time()
@@ -442,6 +452,7 @@ class WindowFunctionTests(PySparkStreamingTestCase):
class StreamingContextTests(PySparkStreamingTestCase):

    duration = 0.1
    setupCalled = False

    def _add_input_stream(self):
        inputs = [range(1, x) for x in range(101)]
@@ -515,10 +526,85 @@ class StreamingContextTests(PySparkStreamingTestCase):
        self.assertEqual([2, 3, 1], self._take(dstream, 3))
    def test_get_active(self):
        self.assertEqual(StreamingContext.getActive(), None)

        # Verify that getActive() returns the active context
        self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
        self.ssc.start()
        self.assertEqual(StreamingContext.getActive(), self.ssc)

        # Verify that getActive() returns None after the context is stopped
        self.ssc.stop(False)
        self.assertEqual(StreamingContext.getActive(), None)

        # Verify that if the Java context is stopped, then getActive() returns None
        self.ssc = StreamingContext(self.sc, self.duration)
        self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
        self.ssc.start()
        self.assertEqual(StreamingContext.getActive(), self.ssc)
        self.ssc._jssc.stop(False)
        self.assertEqual(StreamingContext.getActive(), None)
    def test_get_active_or_create(self):
        # Test StreamingContext.getActiveOrCreate() without checkpoint data
        # See CheckpointTests for tests with checkpoint data
        self.ssc = None
        self.assertEqual(StreamingContext.getActive(), None)

        def setupFunc():
            ssc = StreamingContext(self.sc, self.duration)
            ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
            self.setupCalled = True
            return ssc

        # Verify that getActiveOrCreate() (w/o checkpoint) calls setupFunc when no context is active
        self.setupCalled = False
        self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
        self.assertTrue(self.setupCalled)

        # Verify that getActiveOrCreate() returns the active context and does not call setupFunc
        self.ssc.start()
        self.setupCalled = False
        self.assertEqual(StreamingContext.getActiveOrCreate(None, setupFunc), self.ssc)
        self.assertFalse(self.setupCalled)

        # Verify that getActiveOrCreate() calls setupFunc after the active context is stopped
        self.ssc.stop(False)
        self.setupCalled = False
        self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
        self.assertTrue(self.setupCalled)

        # Verify that getActiveOrCreate() calls setupFunc if the Java context is stopped
        self.ssc = StreamingContext(self.sc, self.duration)
        self.ssc.queueStream([[1]]).foreachRDD(lambda rdd: rdd.count())
        self.ssc.start()
        self.assertEqual(StreamingContext.getActive(), self.ssc)
        self.ssc._jssc.stop(False)
        self.setupCalled = False
        self.ssc = StreamingContext.getActiveOrCreate(None, setupFunc)
        self.assertTrue(self.setupCalled)
class CheckpointTests(unittest.TestCase):

    setupCalled = False

    @staticmethod
    def tearDownClass():
        # Clean up in the JVM just in case there have been issues in the Python API
        jStreamingContextOption = StreamingContext._jvm.StreamingContext.getActive()
        if jStreamingContextOption.nonEmpty():
            jStreamingContextOption.get().stop()
        jSparkContextOption = SparkContext._jvm.SparkContext.get()
        if jSparkContextOption.nonEmpty():
            jSparkContextOption.get().stop()

    def tearDown(self):
        if self.ssc is not None:
            self.ssc.stop(True)
    def test_get_or_create_and_get_active_or_create(self):
        inputd = tempfile.mkdtemp()
        outputd = tempfile.mkdtemp() + "/"
@@ -533,11 +619,12 @@ class CheckpointTests(unittest.TestCase):
            wc = dstream.updateStateByKey(updater)
            wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
            wc.checkpoint(.5)
            self.setupCalled = True
            return ssc

        cpd = tempfile.mkdtemp("test_streaming_cps")
        self.ssc = StreamingContext.getOrCreate(cpd, setup)
        self.ssc.start()
        def check_output(n):
            while not os.listdir(outputd):
@@ -552,7 +639,7 @@ class CheckpointTests(unittest.TestCase):
                    # not finished
                    time.sleep(0.01)
                    continue
                ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
                d = ordd.values().map(int).collect()
                if not d:
                    time.sleep(0.01)
@@ -568,13 +655,37 @@ class CheckpointTests(unittest.TestCase):
        check_output(1)
        check_output(2)

        # Verify that getOrCreate() recovers from checkpoint files
        self.ssc.stop(True, True)
        time.sleep(1)
        self.setupCalled = False
        self.ssc = StreamingContext.getOrCreate(cpd, setup)
        self.assertFalse(self.setupCalled)
        self.ssc.start()
        check_output(3)

        # Verify that getActiveOrCreate() recovers from checkpoint files
        self.ssc.stop(True, True)
        time.sleep(1)
        self.setupCalled = False
        self.ssc = StreamingContext.getActiveOrCreate(cpd, setup)
        self.assertFalse(self.setupCalled)
        self.ssc.start()
        check_output(4)

        # Verify that getActiveOrCreate() returns the active context
        self.setupCalled = False
        self.assertEqual(StreamingContext.getActiveOrCreate(cpd, setup), self.ssc)
        self.assertFalse(self.setupCalled)

        # Verify that getActiveOrCreate() calls setup() in the absence of checkpoint files
        self.ssc.stop(True, True)
        shutil.rmtree(cpd)  # delete checkpoint directory
        self.setupCalled = False
        self.ssc = StreamingContext.getActiveOrCreate(cpd, setup)
        self.assertTrue(self.setupCalled)
        self.ssc.stop(True, True)
class KafkaStreamTests(PySparkStreamingTestCase):
@@ -1134,7 +1245,7 @@ if __name__ == "__main__":
        testcases.append(KinesisStreamTests)
    elif are_kinesis_tests_enabled is False:
        sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was "
                         "not compiled into a JAR. To run these tests, "
                         "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly "
                         "streaming-kinesis-asl-assembly/assembly' or "
                         "'build/mvn -Pkinesis-asl package' before running this test.")
@@ -1150,4 +1261,4 @@ if __name__ == "__main__":
    for testcase in testcases:
        sys.stderr.write("[Running %s]\n" % (testcase))
        tests = unittest.TestLoader().loadTestsFromTestCase(testcase)
        unittest.TextTestRunner(verbosity=3).run(tests)
@@ -158,7 +158,7 @@ def main():
    else:
        log_level = logging.INFO
    logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
    LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE)
    if os.path.exists(LOG_FILE):
        os.remove(LOG_FILE)
    python_execs = opts.python_executables.split(',')
...