From 93c425327560e386aa23b42387cea24eae228e79 Mon Sep 17 00:00:00 2001
From: Kay Ousterhout <kayousterhout@gmail.com>
Date: Wed, 11 Sep 2013 12:12:20 -0700
Subject: [PATCH] Changed localProperties to use ThreadLocal (not
 DynamicVariable).

The fact that DynamicVariable uses an InheritableThreadLocal
can cause problems where the properties end up being shared
across threads in certain circumstances.
---
 .../scala/org/apache/spark/SparkContext.scala  | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 29407bcd30..72540c712a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -27,7 +27,6 @@ import scala.collection.generic.Growable
 import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
-import scala.util.DynamicVariable
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
@@ -257,20 +256,20 @@ class SparkContext(
   private[spark] var checkpointDir: Option[String] = None
 
   // Thread Local variable that can be used by users to pass information down the stack
-  private val localProperties = new DynamicVariable[Properties](null)
+  private val localProperties = new ThreadLocal[Properties]
 
   def initLocalProperties() {
-    localProperties.value = new Properties()
+    localProperties.set(new Properties())
   }
 
   def setLocalProperty(key: String, value: String) {
-    if (localProperties.value == null) {
-      localProperties.value = new Properties()
+    if (localProperties.get() == null) {
+      localProperties.set(new Properties())
     }
     if (value == null) {
-      localProperties.value.remove(key)
+      localProperties.get.remove(key)
     } else {
-      localProperties.value.setProperty(key, value)
+      localProperties.get.setProperty(key, value)
     }
   }
 
@@ -724,7 +723,7 @@ class SparkContext(
     logInfo("Starting job: " + callSite)
     val start = System.nanoTime
     val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,
-      localProperties.value)
+      localProperties.get)
     logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
     rdd.doCheckpoint()
     result
@@ -807,7 +806,8 @@ class SparkContext(
     val callSite = Utils.formatSparkCallSite
     logInfo("Starting job: " + callSite)
     val start = System.nanoTime
-    val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value)
+    val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout,
+      localProperties.get)
     logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
     result
   }
-- 
GitLab