From 25b5b867d5e18bac1c5bcdc6f8c63d97858194c7 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee <matt@redhat.com>
Date: Tue, 9 Sep 2014 18:54:54 -0700
Subject: [PATCH] [SPARK-3458] enable python "with" statements for SparkContext

allow the best-practice cleanup code,

```
sc = SparkContext()
try:
  app(sc)
finally:
  sc.stop()
```

to be written using a "with" statement,

```
with SparkContext() as sc:
  app(sc)
```
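
the "with" form works because python calls `__enter__` on entry and `__exit__` on every
exit path, normal or exceptional. a simplified desugaring sketch (the real protocol also
passes the exception type, value, and traceback to `__exit__`):

```
sc = SparkContext().__enter__()
try:
  app(sc)
finally:
  sc.__exit__(None, None, None)  # simplified; on error the exception details are passed
```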

Author: Matthew Farrellee <matt@redhat.com>

Closes #2335 from mattf/SPARK-3458 and squashes the following commits:

5b4e37c [Matthew Farrellee] [SPARK-3458] enable python "with" statements for SparkContext
---
 python/pyspark/context.py | 14 ++++++++++++++
 python/pyspark/tests.py   | 29 +++++++++++++++++++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 5a30431568..84bc0a3b7c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -232,6 +232,20 @@ class SparkContext(object):
                 else:
                     SparkContext._active_spark_context = instance
 
+    def __enter__(self):
+        """
+        Enable 'with SparkContext(...) as sc: app(sc)' syntax.
+        """
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        """
+        Enable 'with SparkContext(...) as sc: app(sc)' syntax.
+
+        Specifically, stop the context when the with block exits.
+        """
+        self.stop()
+
     @classmethod
     def setSystemProperty(cls, key, value):
         """
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 0bd2a9e6c5..bb84ebe72c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1254,6 +1254,35 @@ class TestSparkSubmit(unittest.TestCase):
         self.assertIn("[2, 4, 6]", out)
 
 
+class ContextStopTests(unittest.TestCase):
+
+    def test_stop(self):
+        sc = SparkContext()
+        self.assertNotEqual(SparkContext._active_spark_context, None)
+        sc.stop()
+        self.assertEqual(SparkContext._active_spark_context, None)
+
+    def test_with(self):
+        with SparkContext() as sc:
+            self.assertNotEqual(SparkContext._active_spark_context, None)
+        self.assertEqual(SparkContext._active_spark_context, None)
+
+    def test_with_exception(self):
+        try:
+            with SparkContext() as sc:
+                self.assertNotEqual(SparkContext._active_spark_context, None)
+                raise Exception()
+        except Exception:
+            pass
+        self.assertEqual(SparkContext._active_spark_context, None)
+
+    def test_with_stop(self):
+        with SparkContext() as sc:
+            self.assertNotEqual(SparkContext._active_spark_context, None)
+            sc.stop()
+        self.assertEqual(SparkContext._active_spark_context, None)
+
+
 @unittest.skipIf(not _have_scipy, "SciPy not installed")
 class SciPyTests(PySparkTestCase):
 
-- 
GitLab