diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index bd24cd19f22fad397596a32ed2b8a03107c5432a..670c8b4caadb5c254db9e32375743bd7cdea988b 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -74,13 +74,21 @@ class SparkConf(loadDefaults: Boolean) extends Serializable with Cloneable {
     this
   }
 
-  /** Set an environment variable to be used when launching executors for this application. */
+  /**
+   * Set an environment variable to be used when launching executors for this application.
+   * These variables are stored as properties of the form spark.executorEnv.VAR_NAME
+   * (for example spark.executorEnv.PATH) but this method makes them easier to set.
+   */
   def setExecutorEnv(variable: String, value: String): SparkConf = {
     settings("spark.executorEnv." + variable) = value
     this
   }
 
-  /** Set multiple environment variables to be used when launching executors. */
+  /**
+   * Set multiple environment variables to be used when launching executors.
+   * These variables are stored as properties of the form spark.executorEnv.VAR_NAME
+   * (for example spark.executorEnv.PATH) but this method makes them easier to set.
+   */
   def setExecutorEnv(variables: Seq[(String, String)]): SparkConf = {
     for ((k, v) <- variables) {
       setExecutorEnv(k, v)
@@ -135,7 +143,7 @@ class SparkConf(loadDefaults: Boolean) extends Serializable with Cloneable {
   }
 
   /** Get all parameters as a list of pairs */
-  def getAll: Seq[(String, String)] = settings.clone().toSeq
+  def getAll: Array[(String, String)] = settings.clone().toArray
 
   /** Get a parameter, falling back to a default if not set */
   def getOrElse(k: String, defaultValue: String): String = {
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 0567f7f43708d6dba6e1f239e1757f9b2f01ec86..c109ff930ca781ae33d1883833cc58b49f33380c 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -78,7 +78,7 @@ class SparkContext(
    * @param conf a [[org.apache.spark.SparkConf]] object specifying other Spark parameters
    */
   def this(master: String, appName: String, conf: SparkConf) =
-    this(conf.setMaster(master).setAppName(appName))
+    this(conf.clone().setMaster(master).setAppName(appName))
 
   /**
    * Alternative constructor that allows setting common Spark properties directly
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 1f35f6f939d8e70c16ec3ebaff5d0b70df59c7bc..f1b95acf097810a9357f661f6bc4f9ca8578e106 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -28,6 +28,8 @@ Public classes:
           A broadcast variable that gets reused across tasks.
     - L{Accumulator<pyspark.accumulators.Accumulator>}
           An "add-only" shared variable that tasks can only add values to.
+    - L{SparkConf<pyspark.conf.SparkConf>}
+          Configuration for a Spark application.
     - L{SparkFiles<pyspark.files.SparkFiles>}
           Access files shipped with jobs.
     - L{StorageLevel<pyspark.storagelevel.StorageLevel>}
@@ -38,10 +40,11 @@ import os
 sys.path.insert(0, os.path.join(os.environ["SPARK_HOME"], "python/lib/py4j0.7.egg"))
 
 
+from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
 from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
 from pyspark.storagelevel import StorageLevel
 
 
-__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel"]
+__all__ = ["SparkConf", "SparkContext", "RDD", "SparkFiles", "StorageLevel"]
diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..56e615c28794923593276d8d816d61a122c43449
--- /dev/null
+++ b/python/pyspark/conf.py
@@ -0,0 +1,113 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+>>> from pyspark.conf import SparkConf
+>>> from pyspark.context import SparkContext
+>>> conf = SparkConf()
+>>> conf.setMaster("local").setAppName("My app")
+<pyspark.conf.SparkConf object at ...>
+>>> conf.get("spark.master")
+u'local'
+>>> conf.get("spark.appName")
+u'My app'
+>>> sc = SparkContext(conf=conf)
+>>> sc.master
+u'local'
+>>> sc.appName
+u'My app'
+>>> sc.sparkHome == None
+True
+
+>>> conf = SparkConf()
+>>> conf.setSparkHome("/path")
+<pyspark.conf.SparkConf object at ...>
+>>> conf.get("spark.home")
+u'/path'
+>>> conf.setExecutorEnv("VAR1", "value1")
+<pyspark.conf.SparkConf object at ...>
+>>> conf.setExecutorEnv(pairs = [("VAR3", "value3"), ("VAR4", "value4")])
+<pyspark.conf.SparkConf object at ...>
+>>> conf.get("spark.executorEnv.VAR1")
+u'value1'
+>>> sorted(conf.getAll(), key=lambda p: p[0])
+[(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), (u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')]
+"""
+
+
+class SparkConf(object):
+    def __init__(self, loadDefaults=False):
+        from pyspark.context import SparkContext
+        SparkContext._ensure_initialized()
+        self._jconf = SparkContext._jvm.SparkConf(loadDefaults)
+
+    def set(self, key, value):
+        self._jconf.set(key, value)
+        return self
+
+    def setMaster(self, value):
+        self._jconf.setMaster(value)
+        return self
+
+    def setAppName(self, value):
+        self._jconf.setAppName(value)
+        return self
+
+    def setSparkHome(self, value):
+        self._jconf.setSparkHome(value)
+        return self
+
+    def setExecutorEnv(self, key=None, value=None, pairs=None):
+        if (key != None and pairs != None) or (key == None and pairs == None):
+            raise Exception("Either pass one key-value pair or a list of pairs")
+        elif key != None:
+            self._jconf.setExecutorEnv(key, value)
+        elif pairs != None:
+            for (k, v) in pairs:
+                self._jconf.setExecutorEnv(k, v)
+        return self
+
+    def setAll(self, pairs):
+        for (k, v) in pairs:
+            self._jconf.set(k, v)
+        return self
+
+    def get(self, key):
+        return self._jconf.get(key)
+
+    def getOrElse(self, key, defaultValue):
+        return self._jconf.getOrElse(key, defaultValue)
+
+    def getAll(self):
+        pairs = []
+        for elem in self._jconf.getAll():
+            pairs.append((elem._1(), elem._2()))
+        return pairs
+
+    def contains(self, key):
+        return self._jconf.contains(key)
+
+
+def _test():
+    import doctest
+    (failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)
+    if failure_count:
+        exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index c0645b2847b4db6c855539389f742265d92fbb45..97c1526afd7a905dbf3c2b4d641d3d24238185c3 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -24,6 +24,7 @@ from tempfile import NamedTemporaryFile
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
+from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
@@ -49,14 +50,14 @@ class SparkContext(object):
 
     _python_includes = None # zip and egg files that need to be added to PYTHONPATH
 
-    def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
-        environment=None, batchSize=1024, serializer=PickleSerializer()):
+    def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
+        environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None):
         """
         Create a new SparkContext.
 
         @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
-        @param jobName: A name for your job, to display on the cluster web UI
+        @param appName: A name for your job, to display on the cluster web UI.
         @param sparkHome: Location where Spark is installed on cluster nodes.
         @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH.  These can be paths on the local file
@@ -67,6 +68,7 @@ class SparkContext(object):
               Java object.  Set 1 to disable batching or -1 to
               use an unlimited batch size.
         @param serializer: The serializer for RDDs.
+        @param conf: A L{SparkConf} object setting Spark properties.
 
 
         >>> from pyspark.context import SparkContext
@@ -79,10 +81,7 @@ class SparkContext(object):
         """
         SparkContext._ensure_initialized(self)
 
-        self.master = master
-        self.jobName = jobName
-        self.sparkHome = sparkHome or None # None becomes null in Py4J
-        self.environment = environment or {}
+        self.conf = conf or SparkConf()
         self._batchSize = batchSize  # -1 represents an unlimited batch size
         self._unbatched_serializer = serializer
         if batchSize == 1:
@@ -91,10 +90,26 @@ class SparkContext(object):
             self.serializer = BatchedSerializer(self._unbatched_serializer,
                                                 batchSize)
 
+        # Set parameters passed directly on our conf; these operations will be no-ops
+        # if the parameters were None
+        self.conf.setMaster(master)
+        self.conf.setAppName(appName)
+        self.conf.setSparkHome(sparkHome)
+        environment = environment or {}
+        for key, value in environment.iteritems():
+            self.conf.setExecutorEnv(key, value)
+
+        if not self.conf.contains("spark.master"):
+            raise Exception("A master URL must be set in your configuration")
+        if not self.conf.contains("spark.appName"):
+            raise Exception("An application name must be set in your configuration")
+
+        self.master = self.conf.get("spark.master")
+        self.appName = self.conf.get("spark.appName")
+        self.sparkHome = self.conf.getOrElse("spark.home", None)
+
         # Create the Java SparkContext through Py4J
-        empty_string_array = self._gateway.new_array(self._jvm.String, 0)
-        self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
-            empty_string_array)
+        self._jsc = self._jvm.JavaSparkContext(self.conf._jconf)
 
         # Create a single Accumulator in Java that we'll send all our updates through;
         # they will be passed back to us through a TCP server
@@ -105,6 +120,7 @@
             self._jvm.PythonAccumulatorParam(host, port))
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
+
         # Broadcast's __reduce__ method stores Broadcast instances here.
        # This allows other code to determine which Broadcast instances have
         # been pickled, so it can determine which Java broadcast objects to
@@ -143,8 +159,8 @@ class SparkContext(object):
     @classmethod
     def setSystemProperty(cls, key, value):
         """
-        Set a system property, such as spark.executor.memory. This must be
-        invoked before instantiating SparkContext.
+        Set a Java system property, such as spark.executor.memory. This must
+        be invoked before instantiating SparkContext.
         """
         SparkContext._ensure_initialized()
         SparkContext._jvm.java.lang.System.setProperty(key, value)
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index e615c1e9b63a46a857aa575561522410c2a6ce7e..128f078d12c1f59566aa02ca00fee27228c9dbab 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -60,6 +60,7 @@ def launch_gateway():
     # Connect to the gateway
     gateway = JavaGateway(GatewayClient(port=port), auto_convert=False)
     # Import the classes used by PySpark
+    java_import(gateway.jvm, "org.apache.spark.SparkConf")
     java_import(gateway.jvm, "org.apache.spark.api.java.*")
     java_import(gateway.jvm, "org.apache.spark.api.python.*")
     java_import(gateway.jvm, "scala.Tuple2")
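
Usage sketch (not part of the patch): a minimal example of how the new PySpark configuration API added above would be used, assuming SPARK_HOME is set and the Py4J gateway can be launched. The property keys (spark.master, spark.appName, spark.executorEnv.*) are the ones written by this diff.

    from pyspark import SparkConf, SparkContext

    # Each setter returns the conf itself, so calls can be chained.
    conf = (SparkConf()
            .setMaster("local")
            .setAppName("My app")
            .setExecutorEnv("VAR1", "value1"))

    # Values are stored as dotted Spark properties on the underlying JVM SparkConf.
    assert conf.get("spark.master") == u"local"
    assert conf.get("spark.executorEnv.VAR1") == u"value1"

    # The conf object replaces the old positional master/jobName constructor arguments.
    sc = SparkContext(conf=conf)
    assert sc.master == u"local" and sc.appName == u"My app"

Keyword arguments passed directly to the SparkContext constructor are applied on top of the supplied conf, and are ignored when left as None, per the comment added in context.py above.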