diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 729f518b89c06013e2fb3213c61b9997b81387a5..dc5b25d845dc279e171ac4fb14adb663943a1620 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -272,8 +272,10 @@ class DAGScheduler(
     if (mapOutputTracker.has(shuffleDep.shuffleId)) {
       val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
       val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
-      for (i <- 0 until locs.size) stage.outputLocs(i) = List(locs(i))
-      stage.numAvailableOutputs = locs.size
+      for (i <- 0 until locs.size) {
+        stage.outputLocs(i) = Option(locs(i)).toList   // locs(i) will be null if missing
+      }
+      stage.numAvailableOutputs = locs.count(_ != null)
     } else {
       // Kind of ugly: need to register RDDs with the cache and map output tracker here
       // since we can't do it in the RDD constructor because # of partitions is unknown
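The added lines in the hunk above rely on Scala's null-safe Option factory: Option(x) is None when x is null, so Option(locs(i)).toList yields an empty list for a missing map output, and locs.count(_ != null) counts only the outputs that are actually present. A minimal standalone sketch of that idiom (plain Scala, with a hypothetical Array[String] standing in for the deserialized MapStatus array):

object NullSafeOutputsSketch {
  def main(args: Array[String]): Unit = {
    // Stand-in for the deserialized statuses; the second output is missing (null).
    val locs: Array[String] = Array("hostA", null, "hostB")

    // Mirror of stage.outputLocs: one (possibly empty) list per map partition.
    val outputLocs: Array[List[String]] = Array.fill(locs.length)(Nil)
    for (i <- 0 until locs.length) {
      outputLocs(i) = Option(locs(i)).toList  // Nil where locs(i) is null
    }

    // Mirror of stage.numAvailableOutputs: count only the non-null entries.
    val numAvailableOutputs = locs.count(_ != null)

    println(outputLocs.toList)    // List(List(hostA), List(), List(hostB))
    println(numAvailableOutputs)  // 2
  }
}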
@@ -373,25 +375,26 @@ class DAGScheduler(
           } else {
             def removeStage(stageId: Int) {
               // data structures based on Stage
-              stageIdToStage.get(stageId).foreach { s =>
-                if (running.contains(s)) {
+              for (stage <- stageIdToStage.get(stageId)) {
+                if (running.contains(stage)) {
                   logDebug("Removing running stage %d".format(stageId))
-                  running -= s
+                  running -= stage
+                }
+                stageToInfos -= stage
+                for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) {
+                  shuffleToMapStage.remove(k)
                 }
-                stageToInfos -= s
-                shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleId =>
-                  shuffleToMapStage.remove(shuffleId))
-                if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) {
+                if (pendingTasks.contains(stage) && !pendingTasks(stage).isEmpty) {
                   logDebug("Removing pending status for stage %d".format(stageId))
                 }
-                pendingTasks -= s
-                if (waiting.contains(s)) {
+                pendingTasks -= stage
+                if (waiting.contains(stage)) {
                   logDebug("Removing stage %d from waiting set.".format(stageId))
-                  waiting -= s
+                  waiting -= stage
                 }
-                if (failed.contains(s)) {
+                if (failed.contains(stage)) {
                   logDebug("Removing stage %d from failed set.".format(stageId))
-                  failed -= s
+                  failed -= stage
                 }
               }
               // data structures based on StageId
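The removeStage rewrite above treats Option as a zero-or-one-element collection: the for over stageIdToStage.get(stageId) runs the body only when the stage exists, and find yields at most one matching shuffle id to remove (equivalent to the old filter-over-all-keys loop as long as each map stage is registered under a single shuffle id). A small standalone sketch of the idiom, using hypothetical string values in place of Stage objects:

object RemoveStageIdiomSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical stand-ins for the scheduler's maps: shuffleId -> stage, stageId -> stage.
    val shuffleToMapStage = scala.collection.mutable.Map(0 -> "stageA", 1 -> "stageB")
    val stageIdToStage = scala.collection.mutable.Map(10 -> "stageB")

    // Iterating an Option runs the body zero or one times, replacing .foreach { s => ... }.
    for (stage <- stageIdToStage.get(10)) {
      // find returns the first matching entry as an Option, so at most one key is removed.
      for ((k, _) <- shuffleToMapStage.find(_._2 == stage)) {
        shuffleToMapStage.remove(k)
      }
    }

    println(shuffleToMapStage)  // Map(0 -> stageA)
  }
}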
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index ac3c86778d526034fc25465fe615d1e17e7cae32..f3fb64d87a2fd1df3b169d9c18505c6fd9fc59f0 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -81,6 +81,19 @@ class FailureSuite extends FunSuite with LocalSparkContext {
     FailureSuiteState.clear()
   }
 
+  // Run a map-reduce job in which the map stage always fails.
+  test("failure in a map stage") {
+    sc = new SparkContext("local", "test")
+    val data = sc.makeRDD(1 to 3).map(x => { throw new Exception; (x, x) }).groupByKey(3)
+    intercept[SparkException] {
+      data.collect()
+    }
+    // Make sure that running new jobs with the same map stage also fails
+    intercept[SparkException] {
+      data.collect()
+    }
+  }
+
   test("failure because task results are not serializable") {
     sc = new SparkContext("local[1,1]", "test")
     val results = sc.makeRDD(1 to 3).map(x => new NonSerializable)
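The new test calls data.collect() twice because the second run submits a fresh job that reuses the same, already-failed map stage, which is the scheduler path the DAGScheduler hunks above change. It relies on ScalaTest's intercept, which runs the block, fails the test if no exception of the expected type is thrown, and returns the caught exception. A self-contained sketch of that pattern (the suite name and message below are illustrative, not part of the patch):

import org.scalatest.FunSuite

class InterceptSketchSuite extends FunSuite {
  test("intercept returns the thrown exception for inspection") {
    val e = intercept[IllegalStateException] {
      throw new IllegalStateException("boom")
    }
    // The returned exception can be inspected further, e.g. its message.
    assert(e.getMessage === "boom")
  }
}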