diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 65788f4646d91118049bf4e30ac477f0ecb0f91d..53df599cf8121047a74d5b5fe208069e223559e7 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -8,7 +8,11 @@
 
 package org.apache.spark.repl
 
+import java.net.URL
+
+import scala.reflect.io.AbstractFile
 import scala.tools.nsc._
+import scala.tools.nsc.backend.JavaPlatform
 import scala.tools.nsc.interpreter._
 
 import scala.tools.nsc.interpreter.{ Results => IR }
@@ -22,11 +26,10 @@ import scala.tools.util.{ Javap }
 import scala.annotation.tailrec
 import scala.collection.mutable.ListBuffer
 import scala.concurrent.ops
-import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream }
+import scala.tools.nsc.util._
 import scala.tools.nsc.interpreter._
-import scala.tools.nsc.io.{ File, Directory }
+import scala.tools.nsc.io.{File, Directory}
 import scala.reflect.NameTransformer._
-import scala.tools.nsc.util.ScalaClassLoader
 import scala.tools.nsc.util.ScalaClassLoader._
 import scala.tools.util._
 import scala.language.{implicitConversions, existentials}
@@ -711,22 +714,24 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
         added = true
         addedClasspath = ClassPath.join(addedClasspath, f.path)
         totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
+        intp.addUrlsToClassPath(f.toURI.toURL)
+        sparkContext.addJar(f.toURI.toURL.getPath)
       }
     }
-    if (added) replay()
   }
 
   def addClasspath(arg: String): Unit = {
     val f = File(arg).normalize
     if (f.exists) {
       addedClasspath = ClassPath.join(addedClasspath, f.path)
-      val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
-      echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, totalClasspath))
-      replay()
+      intp.addUrlsToClassPath(f.toURI.toURL)
+      sparkContext.addJar(f.toURI.toURL.getPath)
+      echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, intp.global.classPath.asClasspathString))
     }
     else echo("The path '" + f + "' doesn't seem to exist.")
   }
 
+
   def powerCmd(): Result = {
     if (isReplPower) "Already in power mode."
     else enablePowerMode(false)
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index 84b57cd2dc1af6b85b97e903acb13eadea60734d..6ddb6accd696b7ab09fd5682a7894334b1863b3f 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -7,11 +7,14 @@
 
 package org.apache.spark.repl
 
+import java.io.File
+
 import scala.tools.nsc._
+import scala.tools.nsc.backend.JavaPlatform
 import scala.tools.nsc.interpreter._
 
 import Predef.{ println => _, _ }
-import util.stringFromWriter
+import scala.tools.nsc.util.{MergedClassPath, stringFromWriter, ScalaClassLoader, stackTraceString}
 import scala.reflect.internal.util._
 import java.net.URL
 import scala.sys.BooleanProp
@@ -21,7 +24,6 @@ import reporters._
 import symtab.Flags
 import scala.reflect.internal.Names
 import scala.tools.util.PathResolver
-import scala.tools.nsc.util.ScalaClassLoader
 import ScalaClassLoader.URLClassLoader
 import scala.tools.nsc.util.Exceptional.unwrap
 import scala.collection.{ mutable, immutable }
@@ -34,7 +36,6 @@ import scala.reflect.runtime.{ universe => ru }
 import scala.reflect.{ ClassTag, classTag }
 import scala.tools.reflect.StdRuntimeTags._
 import scala.util.control.ControlThrowable
-import util.stackTraceString
 
 import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf}
 import org.apache.spark.util.Utils
@@ -130,6 +131,9 @@ import org.apache.spark.util.Utils
   private var _classLoader: AbstractFileClassLoader = null // active classloader
   private val _compiler: Global = newCompiler(settings, reporter) // our private compiler
 
+  private trait ExposeAddUrl extends URLClassLoader { def addNewUrl(url: URL) = this.addURL(url) }
+  private var _runtimeClassLoader: URLClassLoader with ExposeAddUrl = null // wrapper exposing addURL
+
   private val nextReqId = {
     var counter = 0
     () => { counter += 1 ; counter }
@@ -308,6 +312,57 @@ import org.apache.spark.util.Utils
     }
   }
 
+  /**
+   * Adds any specified jars to the compile and runtime classpaths.
+   *
+   * @note Currently only supports jars, not directories
+   * @param urls The list of items to add to the compile and runtime classpaths
+   */
+  def addUrlsToClassPath(urls: URL*): Unit = {
+    new Run // Needed to force initialization of "something" to correctly load Scala classes from jars
+    urls.foreach(_runtimeClassLoader.addNewUrl) // Add jars/classes to runtime for execution
+    updateCompilerClassPath(urls: _*) // Add jars/classes to compile time for compiling
+  }
+
+  protected def updateCompilerClassPath(urls: URL*): Unit = {
+    require(!global.forMSIL) // Only support JavaPlatform
+
+    val platform = global.platform.asInstanceOf[JavaPlatform]
+
+    val newClassPath = mergeUrlsIntoClassPath(platform, urls: _*)
+
+    // NOTE: Must use reflection until this is exposed/fixed upstream in Scala
+    val fieldSetter = platform.getClass.getMethods
+      .find(_.getName.endsWith("currentClassPath_$eq")).get
+    fieldSetter.invoke(platform, Some(newClassPath))
+
+    // Reload all jars specified into our compiler
+    global.invalidateClassPathEntries(urls.map(_.getPath): _*)
+  }
+
+  protected def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = {
+    // Collect our new jars/directories and add them to the existing set of classpaths
+    val allClassPaths = (
+      platform.classPath.asInstanceOf[MergedClassPath[AbstractFile]].entries ++
+      urls.map(url => {
+        platform.classPath.context.newClassPath(
+          if (url.getProtocol == "file") {
+            val f = new File(url.getPath)
+            if (f.isDirectory)
+              io.AbstractFile.getDirectory(f)
+            else
+              io.AbstractFile.getFile(f)
+          } else {
+            io.AbstractFile.getURL(url)
+          }
+        )
+      })
+    ).distinct
+
+    // Combine all of our classpaths (old and new) into one merged classpath
+    new MergedClassPath(allClassPaths, platform.classPath.context)
+  }
+
   /** Parent classloader. Overridable. */
   protected def parentClassLoader: ClassLoader =
     SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() )
@@ -356,7 +411,9 @@ import org.apache.spark.util.Utils
   private def makeClassLoader(): AbstractFileClassLoader =
     new TranslatingClassLoader(parentClassLoader match {
       case null => ScalaClassLoader fromURLs compilerClasspath
-      case p    => new URLClassLoader(compilerClasspath, p)
+      case p    =>
+        _runtimeClassLoader = new URLClassLoader(compilerClasspath, p) with ExposeAddUrl
+        _runtimeClassLoader
     })
 
   def getInterpreterClassLoader() = classLoader
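
Taken together, these changes make `:cp` (which routes through `addClasspath`) useful again: instead of appending the jar to `addedClasspath` and replaying the entire session, the jar is injected into both the compiler classpath and the REPL's runtime classloader via `intp.addUrlsToClassPath`, and shipped to the executors via `sparkContext.addJar`. A rough session sketch, where the jar path and `com.example.Helper` are hypothetical:

    scala> :cp /tmp/helper.jar
    Added '/tmp/helper.jar'. Your new classpath is:
    "...:/tmp/helper.jar"

    scala> import com.example.Helper
    import com.example.Helper

    scala> sc.parallelize(1 to 4).map(Helper.label(_)).collect()

After the command, both driver-side compilation/execution and executor-side tasks can resolve classes from the added jar, with no replay of prior REPL lines.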
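
The `ExposeAddUrl` trait exists because `addURL` is protected on `java.net.URLClassLoader`; mixing the trait into the loader built in `makeClassLoader` widens it to the public `addNewUrl` that `addUrlsToClassPath` calls. A minimal standalone sketch of the same widening trick, using a plain subclass and an illustrative jar path:

    import java.io.File
    import java.net.{URL, URLClassLoader}

    // addURL is protected in java.net.URLClassLoader; a subclass can
    // re-expose it behind a public method, as ExposeAddUrl does via a trait.
    class MutableLoader(parent: ClassLoader)
        extends URLClassLoader(Array.empty[URL], parent) {
      def addNewUrl(url: URL): Unit = addURL(url)
    }

    object LoaderDemo extends App {
      val loader = new MutableLoader(getClass.getClassLoader)
      loader.addNewUrl(new File("/tmp/helper.jar").toURI.toURL) // illustrative path
      println(loader.getURLs.toList) // List(file:/tmp/helper.jar)
    }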
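
The reflective lookup of `currentClassPath_$eq` in `updateCompilerClassPath` relies on a quirk of the Scala 2.10/2.11 trait encoding this patch targets: a trait-private `var` such as `JavaPlatform.currentClassPath` compiles to name-mangled accessor methods that are public in bytecode, even though they are inaccessible from Scala source. A self-contained sketch of that technique, with all names illustrative rather than taken from the patch:

    // A trait-private var, analogous to currentClassPath in JavaPlatform.
    trait Platformish {
      private var current: Option[String] = None
      def report: String = current.toString
    }

    object SetterDemo extends App {
      val p = new Platformish {}
      // The generated setter is name-mangled (e.g. "Platformish$$current_$eq")
      // but public in bytecode, so getMethods can locate it, just as the
      // patch locates currentClassPath's setter on JavaPlatform.
      val setter = p.getClass.getMethods.find(_.getName.endsWith("current_$eq")).get
      setter.invoke(p, Some("updated"))
      println(p.report) // Some(updated)
    }

This is why the patch matches with `endsWith` rather than an exact method name: the mangled prefix encodes the defining trait's package path and is brittle to spell out.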