diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 3c92c205ea978c58fd00b8ff86077bd2d9cbca2e..e51d274d338748cd48a162d39458326012b921cb 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -141,11 +141,6 @@ private[spark] class Executor(
     val tr = runningTasks.get(taskId)
     if (tr != null) {
       tr.kill()
-      // We remove the task also in the finally block in TaskRunner.run.
-      // The reason we need to remove it here is because killTask might be called before the task
-      // is even launched, and never reaching that finally block. ConcurrentHashMap's remove is
-      // idempotent.
-      runningTasks.remove(taskId)
     }
   }
 
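The deleted comment notes that removal also happens in TaskRunner.run's finally block; with the kill path now flowing through the normal catch/finally sequence (see the hunks below), the eager remove here is dropped. A minimal sketch of that ownership model, with illustrative names rather than Spark's:

```scala
import java.util.concurrent.ConcurrentHashMap

// Sketch only: the runner is the sole party that unregisters itself, and it
// always does so in `finally`, even when killed before the work starts, so
// kill() only needs to flip the flag.
class Runner(id: Long, registry: ConcurrentHashMap[Long, Runner]) extends Runnable {
  @volatile private var killed = false
  def kill(): Unit = { killed = true }

  override def run(): Unit = {
    try {
      if (killed) throw new Exception("killed")  // stand-in for TaskKilledException
      // ... run the task ...
    } catch {
      case _: Exception => () // report KILLED (or FAILED) to the backend here
    } finally {
      registry.remove(id)  // always executed, so no eager remove in kill()
    }
  }
}
```
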
@@ -167,6 +162,8 @@ private[spark] class Executor(
   class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
     extends Runnable {
 
+    object TaskKilledException extends Exception
+
     @volatile private var killed = false
     @volatile private var task: Task[Any] = _
 
@@ -200,9 +197,11 @@ private[spark] class Executor(
         // If this task has been killed before we deserialized it, let's quit now. Otherwise,
         // continue executing the task.
         if (killed) {
-          logInfo("Executor killed task " + taskId)
-          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
-          return
+          // Throw an exception rather than returning, because returning within a try{} block
+          // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
+          // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
+          // for the task.
+          throw TaskKilledException
         }
 
         attemptedTask = Some(task)
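The comment above deserves unpacking: scalac implements a `return` that executes inside a closure by throwing `scala.runtime.NonLocalReturnControl`, so an enclosing catch-all intercepts the control-flow exception and the early return is misreported as a failure. A hypothetical standalone repro of that mechanism (names are mine, assuming the `return` sits inside a closure, as the comment implies it does here):

```scala
object NonLocalReturnRepro extends App {
  def firstBig(items: Seq[Int]): Option[Int] = {
    try {
      items.foreach { i =>
        // `return` inside this closure compiles to a thrown
        // scala.runtime.NonLocalReturnControl carrying Some(i).
        if (i > 10) return Some(i)
      }
      None
    } catch {
      case t: Throwable =>
        // The catch-all swallows the control exception: the caller gets None
        // instead of Some(20), the analogue of the bogus ExceptionFailure.
        println("caught: " + t.getClass.getName)
        None
    }
  }

  println(firstBig(Seq(5, 20)))  // prints the caught class name, then None
}
```
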
@@ -216,9 +215,7 @@ private[spark] class Executor(
 
         // If the task has been killed, let's fail it.
         if (task.killed) {
-          logInfo("Executor killed task " + taskId)
-          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
-          return
+          throw TaskKilledException
         }
 
         val resultSer = SparkEnv.get.serializer.newInstance()
@@ -260,6 +257,11 @@ private[spark] class Executor(
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
         }
 
+        case TaskKilledException => {
+          logInfo("Executor killed task " + taskId)
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+        }
+
         case t: Throwable => {
           val serviceTime = (System.currentTimeMillis() - taskStart).toInt
           val metrics = attemptedTask.flatMap(t => t.metrics)
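One detail the hunk above relies on: catch clauses are tried top to bottom, so `case TaskKilledException` must precede the catch-all `case t: Throwable`, otherwise kills would still surface as failures. A quick sketch:

```scala
object CatchOrderDemo extends App {
  object Killed extends Exception

  try throw Killed
  catch {
    case Killed       => println("reported as KILLED")  // must come first
    case t: Throwable => println("reported as FAILED")  // would shadow Killed if listed first
  }
}
```
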
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index c4ac8337c568ea8649df806ce0765084b5a82639..0c8ed6275991a4cd7f8def626244c1f73ed8cd28 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -285,7 +285,8 @@ private[spark] class TaskSchedulerImpl(
               }
             }
           case None =>
-            logInfo("Ignoring update from TID " + tid + " because its task set is gone")
+            logInfo("Ignoring update with state %s from TID %s because its task set is gone"
+              .format(state, tid))
         }
       } catch {
         case e: Exception => logError("Exception in statusUpdate", e)
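For reference, the reworked message uses Scala's `.format` enrichment on String; with made-up values it renders like this:

```scala
object LogMessageDemo extends App {
  // Hypothetical values, purely to show the rendered message:
  val state = "FINISHED"
  val tid = 102L
  println("Ignoring update with state %s from TID %s because its task set is gone"
    .format(state, tid))
  // => Ignoring update with state FINISHED from TID 102 because its task set is gone
}
```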