From 284771efbef2d6b22212afd49dd62732a2cf52a8 Mon Sep 17 00:00:00 2001
From: Ye Xianjin <advancedxy@gmail.com>
Date: Fri, 1 Aug 2014 00:34:39 -0700
Subject: [PATCH] [SPARK-2557] Fix LOCAL_N_REGEX in createTaskScheduler and
 make local-n and local-n-failures consistent

[SPARK-2557](https://issues.apache.org/jira/browse/SPARK-2557)

Author: Ye Xianjin <advancedxy@gmail.com>

Closes #1464 from advancedxy/SPARK-2557 and squashes the following commits:

d844d67 [Ye Xianjin] add local-*-n-failures, bad-local-n, and bad-local-n-failures test cases
3bbc668 [Ye Xianjin] fix the LOCAL_N_REGEX regular expression and make local-n-failures accept * to mean all cores on the machine
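
For illustration, here is a minimal standalone sketch (not part of the patch; the object and helper names are made up) of why the old pattern was too loose: the character class `[0-9\*]+` accepts any mix of digits and asterisks, so a malformed master like `local[2*]` slipped through, while the new alternation `([0-9]+|\*)` accepts only pure digits or a single `*`:

```scala
import scala.util.matching.Regex

// Hypothetical demo object, not part of Spark's codebase.
object LocalNRegexDemo extends App {
  val oldLocalN: Regex = """local\[([0-9\*]+)\]""".r  // pattern before this patch
  val newLocalN: Regex = """local\[([0-9]+|\*)\]""".r // pattern after this patch

  // Scala's Regex extractor only succeeds on a whole-string match,
  // which is how createTaskScheduler pattern-matches the master URL.
  def matches(r: Regex, master: String): Boolean = master match {
    case r(_) => true
    case _    => false
  }

  for (m <- Seq("local[4]", "local[*]", "local[2*]"))
    println(s"$m -> old=${matches(oldLocalN, m)}, new=${matches(newLocalN, m)}")

  // local[4]  -> old=true, new=true
  // local[*]  -> old=true, new=true
  // local[2*] -> old=true, new=false  (malformed master now rejected)
}
```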
---
 .../scala/org/apache/spark/SparkContext.scala | 10 +++++---
 .../SparkContextSchedulerCreationSuite.scala  | 23 +++++++++++++++++++
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index f5a0549834..0e513568b0 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1452,9 +1452,9 @@ object SparkContext extends Logging {
   /** Creates a task scheduler based on a given master URL. Extracted for testing. */
   private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = {
     // Regular expression used for local[N] and local[*] master formats
-    val LOCAL_N_REGEX = """local\[([0-9\*]+)\]""".r
+    val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
     // Regular expression for local[N, maxRetries], used in tests with failing tasks
-    val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
+    val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r
     // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
     val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
     // Regular expression for connecting to Spark deploy clusters
@@ -1484,8 +1484,12 @@ object SparkContext extends Logging {
         scheduler
 
       case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
+        def localCpuCount = Runtime.getRuntime.availableProcessors()
+        // local[*, M] means as many threads as cores on the machine, with M task failures allowed
+        // local[N, M] means exactly N threads, with M task failures allowed
+        val threadCount = if (threads == "*") localCpuCount else threads.toInt
         val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true)
-        val backend = new LocalBackend(scheduler, threads.toInt)
+        val backend = new LocalBackend(scheduler, threadCount)
         scheduler.initialize(backend)
         scheduler
 
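
Similarly, a quick sketch (again hypothetical, not part of the patch) of what the updated LOCAL_N_FAILURES_REGEX now accepts: `*` for the thread count and optional whitespace around the comma, which the new test cases below rely on, while a malformed spec such as `local[2*,4]` still fails to parse:

```scala
// Hypothetical demo object, not part of Spark's codebase.
object LocalNFailuresRegexDemo extends App {
  // The pattern as introduced by this patch.
  val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r

  Seq("local[4, 2]", "local[* ,2]", "local[2*,4]").foreach {
    case m @ LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
      println(s"$m -> threads=$threads, maxFailures=$maxFailures")
    case m =>
      println(s"$m -> no match; createTaskScheduler throws 'Could not parse Master URL'")
  }
  // local[4, 2] -> threads=4, maxFailures=2
  // local[* ,2] -> threads=*, maxFailures=2
  // local[2*,4] -> no match; createTaskScheduler throws 'Could not parse Master URL'
}
```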
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
index 67e3be21c3..4b727e50db 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -68,6 +68,15 @@ class SparkContextSchedulerCreationSuite
     }
   }
 
+  test("local-*-n-failures") {
+    val sched = createTaskScheduler("local[* ,2]")
+    assert(sched.maxTaskFailures === 2)
+    sched.backend match {
+      case s: LocalBackend => assert(s.totalCores === Runtime.getRuntime.availableProcessors())
+      case _ => fail()
+    }
+  }
+
   test("local-n-failures") {
     val sched = createTaskScheduler("local[4, 2]")
     assert(sched.maxTaskFailures === 2)
@@ -77,6 +86,20 @@ class SparkContextSchedulerCreationSuite
     }
   }
 
+  test("bad-local-n") {
+    val e = intercept[SparkException] {
+      createTaskScheduler("local[2*]")
+    }
+    assert(e.getMessage.contains("Could not parse Master URL"))
+  }
+
+  test("bad-local-n-failures") {
+    val e = intercept[SparkException] {
+      createTaskScheduler("local[2*,4]")
+    }
+    assert(e.getMessage.contains("Could not parse Master URL"))
+  }
+
   test("local-default-parallelism") {
     val defaultParallelism = System.getProperty("spark.default.parallelism")
     System.setProperty("spark.default.parallelism", "16")
-- 
GitLab