From 284771efbef2d6b22212afd49dd62732a2cf52a8 Mon Sep 17 00:00:00 2001 From: Ye Xianjin <advancedxy@gmail.com> Date: Fri, 1 Aug 2014 00:34:39 -0700 Subject: [PATCH] [Spark 2557] fix LOCAL_N_REGEX in createTaskScheduler and make local-n and local-n-failures consistent [SPARK-2557](https://issues.apache.org/jira/browse/SPARK-2557) Author: Ye Xianjin <advancedxy@gmail.com> Closes #1464 from advancedxy/SPARK-2557 and squashes the following commits: d844d67 [Ye Xianjin] add local-*-n-failures, bad-local-n, bad-local-n-failures test case 3bbc668 [Ye Xianjin] fix LOCAL_N_REGEX regular expression and make local_n_failures accept * as all cores on the computer --- .../scala/org/apache/spark/SparkContext.scala | 10 +++++--- .../SparkContextSchedulerCreationSuite.scala | 23 +++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f5a0549834..0e513568b0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1452,9 +1452,9 @@ object SparkContext extends Logging { /** Creates a task scheduler based on a given master URL. Extracted for testing. */ private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = { // Regular expression used for local[N] and local[*] master formats - val LOCAL_N_REGEX = """local\[([0-9\*]+)\]""".r + val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r // Regular expression for simulating a Spark cluster of [N, cores, memory] locally val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters @@ -1484,8 +1484,12 @@ object SparkContext extends Logging { scheduler case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => + def localCpuCount = Runtime.getRuntime.availableProcessors() + // local[*, M] means the number of cores on the computer with M failures + // local[N, M] means exactly N threads with M failures + val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) - val backend = new LocalBackend(scheduler, threads.toInt) + val backend = new LocalBackend(scheduler, threadCount) scheduler.initialize(backend) scheduler diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 67e3be21c3..4b727e50db 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -68,6 +68,15 @@ class SparkContextSchedulerCreationSuite } } + test("local-*-n-failures") { + val sched = createTaskScheduler("local[* ,2]") + assert(sched.maxTaskFailures === 2) + sched.backend match { + case s: LocalBackend => assert(s.totalCores === Runtime.getRuntime.availableProcessors()) + case _ => fail() + } + } + test("local-n-failures") { val sched = createTaskScheduler("local[4, 2]") assert(sched.maxTaskFailures === 2) @@ -77,6 +86,20 @@ class SparkContextSchedulerCreationSuite } } + test("bad-local-n") { + val e = intercept[SparkException] { + createTaskScheduler("local[2*]") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + + test("bad-local-n-failures") { + val e = intercept[SparkException] { + createTaskScheduler("local[2*,4]") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + test("local-default-parallelism") { val defaultParallelism = System.getProperty("spark.default.parallelism") System.setProperty("spark.default.parallelism", "16") -- GitLab