Commit 72a601ec authored by Matei Zaharia

Merge pull request #152 from rxin/repl

Propagate SparkContext local properties from spark-repl caller thread to the repl execution thread.
Parents: dd63c548, 31929994
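Background for the change: SparkContext keeps its local properties in an InheritableThreadLocal, and inheritable values are copied into a child only at the moment the child thread is created. The REPL runs each interpreted line on a separate execution thread, so properties set afterwards by the calling thread never reach it; this patch copies them across explicitly. A standalone sketch (not Spark code) of that pitfall:

    // InheritableThreadLocal copies the value when the child is CREATED, so a
    // thread that already exists never observes later updates from the caller.
    object InheritDemo {
      val prop = new InheritableThreadLocal[String]

      def main(args: Array[String]): Unit = {
        prop.set("before")
        val t = new Thread(() => {
          Thread.sleep(100)   // run after the caller's update below
          println(prop.get)   // prints "before": the update was never seen
        })
        t.start()
        prop.set("after")     // too late; t's copy was made at creation
        t.join()
      }
    }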
core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -280,6 +280,12 @@ class SparkContext(
     override protected def childValue(parent: Properties): Properties = new Properties(parent)
   }
 
+  private[spark] def getLocalProperties(): Properties = localProperties.get()
+
+  private[spark] def setLocalProperties(props: Properties) {
+    localProperties.set(props)
+  }
+
   def initLocalProperties() {
     localProperties.set(new Properties())
   }
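The childValue override above relies on java.util.Properties defaults semantics: new Properties(parent) reads through to the parent until the child writes its own value, so a child thread starts with the parent's view without copying it. A minimal sketch of that behavior:

    import java.util.Properties

    val parent = new Properties()
    parent.setProperty("spark.scheduler.pool", "pool1")

    val child = new Properties(parent)            // what childValue(parent) returns
    child.getProperty("spark.scheduler.pool")     // "pool1", read via defaults
    child.setProperty("spark.scheduler.pool", "pool2") // shadows; parent untouched
    parent.getProperty("spark.scheduler.pool")    // still "pool1"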
repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -878,14 +878,21 @@ class SparkIMain(val settings: Settings, protected val out: PrintWriter) extends
         (message, false)
       }
     }
 
+    // Get a copy of the local properties from SparkContext, and set it later in the thread
+    // that triggers the execution. This is to make sure the caller of this function can pass
+    // the right thread local (inheritable) properties down into Spark.
+    val sc = org.apache.spark.repl.Main.interp.sparkContext
+    val props = if (sc != null) sc.getLocalProperties() else null
+
     try {
       val execution = lineManager.set(originalLine) {
         // MATEI: set the right SparkEnv for our SparkContext, because
         // this execution will happen in a separate thread
-        val sc = org.apache.spark.repl.Main.interp.sparkContext
-        if (sc != null && sc.env != null)
+        if (sc != null && sc.env != null) {
           SparkEnv.set(sc.env)
+          sc.setLocalProperties(props)
+        }
         // Execute the line
         lineRep call "$export"
       }
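The hunk captures the caller's properties before lineManager.set(...) schedules the closure, then re-installs them on the thread that actually runs it. A generic sketch of this capture-and-restore pattern (the helper name propagating is illustrative, not part of Spark):

    import java.util.Properties

    // Snapshot thread-local state where the work is defined,
    // re-install it where the work runs.
    def propagating(local: ThreadLocal[Properties])(body: => Unit): Runnable = {
      val snapshot = local.get()    // captured on the caller thread
      () => {
        local.set(snapshot)         // restored on the execution thread
        body
      }
    }

    // usage: new Thread(propagating(myLocal) { doWork() }).start()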
repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -21,12 +21,14 @@ import java.io._
 import java.net.URLClassLoader
 
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.JavaConversions._
 
-import org.scalatest.FunSuite
 import com.google.common.io.Files
+import org.scalatest.FunSuite
+import org.apache.spark.SparkContext
 
 class ReplSuite extends FunSuite {
   def runInterpreter(master: String, input: String): String = {
     val in = new BufferedReader(new StringReader(input + "\n"))
     val out = new StringWriter()
@@ -64,6 +66,35 @@ class ReplSuite extends FunSuite {
       "Interpreter output contained '" + message + "':\n" + output)
   }
 
+  test("propagation of local properties") {
+    // A mock ILoop that doesn't install the SIGINT handler.
+    class ILoop(out: PrintWriter) extends SparkILoop(None, out, None) {
+      settings = new scala.tools.nsc.Settings
+      settings.usejavacp.value = true
+      org.apache.spark.repl.Main.interp = this
+      override def createInterpreter() {
+        intp = new SparkILoopInterpreter
+        intp.setContextClassLoader()
+      }
+    }
+
+    val out = new StringWriter()
+    val interp = new ILoop(new PrintWriter(out))
+    interp.sparkContext = new SparkContext("local", "repl-test")
+    interp.createInterpreter()
+    interp.intp.initialize()
+    interp.sparkContext.setLocalProperty("someKey", "someValue")
+
+    // Make sure the value we set in the caller to interpret is propagated in the thread that
+    // interprets the command.
+    interp.interpret("org.apache.spark.repl.Main.interp.sparkContext.getLocalProperty(\"someKey\")")
+    assert(out.toString.contains("someValue"))
+
+    interp.sparkContext.stop()
+    System.clearProperty("spark.driver.port")
+    System.clearProperty("spark.hostPort")
+  }
+
   test ("simple foreach with accumulator") {
     val output = runInterpreter("local", """
       val accum = sc.accumulator(0)
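With this change, a local property set on the SparkContext in the REPL's calling thread now applies to jobs triggered by interpreted lines. A hypothetical session illustrating what the propagation enables ("production" is an assumed fair-scheduler pool name, not part of this patch):

    // set in the caller thread of the REPL
    sc.setLocalProperty("spark.scheduler.pool", "production")
    // jobs run by subsequently interpreted lines inherit the property,
    // because the execution thread receives the caller's local properties
    sc.parallelize(1 to 100).count()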