diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index ee82d679935c0e794adaa077d275aef57faf1885..a1a1fb01426a0943d84120cb32e9cde48d682448 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -125,7 +125,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     assert (partitionWriters == null);
     if (!records.hasNext()) {
       partitionLengths = new long[numPartitions];
-      shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+      shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
       mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
       return;
     }
@@ -155,9 +155,10 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       writer.commitAndClose();
     }
 
-    partitionLengths =
-      writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
-    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
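+    // Write the partitioned output to a temporary file and hand it to the resolver, which
+    // atomically commits (or discards) it together with the index file, so that concurrent
+    // attempts for the same map task cannot clobber each other's output.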
+    File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+    File tmp = Utils.tempFileWith(output);
+    partitionLengths = writePartitionedFile(tmp);
+    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
     mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
 
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 6a0a89e81c321fe9db44d98e6971429876f7f284..744c3008ca50e59338b9e5acdf9950c6163ea587 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -41,7 +41,7 @@ import org.apache.spark.annotation.Private;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
-import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
@@ -53,7 +53,7 @@ import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
 
 @Private
 public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@@ -206,8 +206,10 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     final SpillInfo[] spills = sorter.closeAndGetSpills();
     sorter = null;
     final long[] partitionLengths;
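+    // Merge the spill files into a temporary data file; it is only moved into its final
+    // location by writeIndexFileAndCommit() once the merge has fully succeeded.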
+    final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+    final File tmp = Utils.tempFileWith(output);
     try {
-      partitionLengths = mergeSpills(spills);
+      partitionLengths = mergeSpills(spills, tmp);
     } finally {
       for (SpillInfo spill : spills) {
         if (spill.file.exists() && ! spill.file.delete()) {
@@ -215,7 +217,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
         }
       }
     }
-    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+    shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
     mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
 
@@ -248,8 +250,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
    *
    * @return the partition lengths in the merged file.
    */
-  private long[] mergeSpills(SpillInfo[] spills) throws IOException {
-    final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+  private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
     final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
     final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
     final boolean fastMergeEnabled =
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index cd253a78c2b19fb2815ac9d70a817e3475ff40ef..39fadd87835180e160268db45f733d95a42a51fc 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -21,13 +21,13 @@ import java.util.concurrent.ConcurrentLinkedQueue
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.{Logging, SparkConf, SparkEnv}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage._
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
 
 /** A group of writers for a ShuffleMapTask, one writer per reducer. */
 private[spark] trait ShuffleWriterGroup {
@@ -84,17 +84,8 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
         Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId =>
           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
           val blockFile = blockManager.diskBlockManager.getFile(blockId)
-          // Because of previous failures, the shuffle file may already exist on this machine.
-          // If so, remove it.
-          if (blockFile.exists) {
-            if (blockFile.delete()) {
-              logInfo(s"Removed existing shuffle file $blockFile")
-            } else {
-              logWarning(s"Failed to remove existing shuffle file $blockFile")
-            }
-          }
-          blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize,
-            writeMetrics)
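+          // Write to a temporary file; HashShuffleWriter renames it to the final block file
+          // when the task commits, so leftovers from failed attempts cannot get in the way.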
+          val tmp = Utils.tempFileWith(blockFile)
+          blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics)
         }
       }
       // Creating the file to write to and creating a disk writer both involve interacting with
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 5e4c2b5d0a5c478314c0285b3e2d8191f9988a49..05b1eed7f3bef9ead03de6ee43ef83d683c7cff1 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -21,13 +21,12 @@ import java.io._
 
 import com.google.common.io.ByteStreams
 
-import org.apache.spark.{SparkConf, SparkEnv, Logging}
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
-
-import IndexShuffleBlockResolver.NOOP_REDUCE_ID
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
 
 /**
  * Create and maintain the shuffle blocks' mapping between logic block and physical file location.
@@ -40,10 +39,13 @@ import IndexShuffleBlockResolver.NOOP_REDUCE_ID
  */
 // Note: Changes to the format in this file should be kept in sync with
 // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData().
-private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver
+private[spark] class IndexShuffleBlockResolver(
+    conf: SparkConf,
+    _blockManager: BlockManager = null)
+  extends ShuffleBlockResolver
   with Logging {
 
-  private lazy val blockManager = SparkEnv.get.blockManager
+  private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager)
 
   private val transportConf = SparkTransportConf.fromSparkConf(conf)
 
@@ -74,14 +76,69 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
     }
   }
 
+  /**
+   * Check whether the given index and data files match each other.
+   * If so, return the partition lengths in the data file. Otherwise return null.
+   */
+  private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = {
+    // The index file should have `blocks + 1` longs as offsets.
+    if (index.length() != (blocks + 1) * 8) {
+      return null
+    }
+    val lengths = new Array[Long](blocks)
+    // Read the lengths of blocks
+    val in = try {
+      new DataInputStream(new BufferedInputStream(new FileInputStream(index)))
+    } catch {
+      case e: IOException =>
+        return null
+    }
+    try {
+      // Convert the offsets into lengths of each block
+      var offset = in.readLong()
+      if (offset != 0L) {
+        return null
+      }
+      var i = 0
+      while (i < blocks) {
+        val off = in.readLong()
+        lengths(i) = off - offset
+        offset = off
+        i += 1
+      }
+    } catch {
+      case e: IOException =>
+        return null
+    } finally {
+      in.close()
+    }
+
+    // The size of the data file should match the sum of the block lengths from the index file.
+    if (data.length() == lengths.sum) {
+      lengths
+    } else {
+      null
+    }
+  }
+
   /**
    * Write an index file with the offsets of each block, plus a final offset at the end for the
    * end of the output file. This will be used by getBlockData to figure out where each block
    * begins and ends.
+   *
+   * It will commit the data and index files as an atomic operation: either the existing files
+   * are reused, or they are replaced by the new ones.
+   *
+   * Note: `lengths` is updated to match the existing index file if the existing ones are used.
    * */
-  def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = {
+  def writeIndexFileAndCommit(
+      shuffleId: Int,
+      mapId: Int,
+      lengths: Array[Long],
+      dataTmp: File): Unit = {
     val indexFile = getIndexFile(shuffleId, mapId)
-    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+    val indexTmp = Utils.tempFileWith(indexFile)
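+    // Write the offsets into a temporary index file first; it is renamed into place inside the
+    // synchronized block below, once we know whether another attempt has already committed.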
+    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
     Utils.tryWithSafeFinally {
       // We take in lengths of each block, need to convert it to offsets.
       var offset = 0L
@@ -93,6 +150,37 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
     } {
       out.close()
     }
+
+    val dataFile = getDataFile(shuffleId, mapId)
+    // There is only one IndexShuffleBlockResolver per executor; this synchronization makes sure
+    // that the following check and rename are atomic.
+    synchronized {
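+      // See whether a previous attempt has already committed a consistent pair of index and
+      // data files for this map output.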
+      val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
+      if (existingLengths != null) {
+        // Another attempt for the same task has already written our map outputs successfully,
+        // so just use the existing partition lengths and delete our temporary map outputs.
+        System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
+        if (dataTmp != null && dataTmp.exists()) {
+          dataTmp.delete()
+        }
+        indexTmp.delete()
+      } else {
+        // This is the first successful attempt at writing the map outputs for this task,
+        // so overwrite any existing index and data files with the ones we wrote.
+        if (indexFile.exists()) {
+          indexFile.delete()
+        }
+        if (dataFile.exists()) {
+          dataFile.delete()
+        }
+        if (!indexTmp.renameTo(indexFile)) {
+          throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
+        }
+        if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
+          throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
+        }
+      }
+    }
   }
 
   override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 41df70c602c30f83b8558605d4d97aa3ffda40c1..412bf70000da7ac817978cef65e1218e8df2bc2d 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.shuffle.hash
 
+import java.io.IOException
+
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
@@ -106,6 +108,29 @@ private[spark] class HashShuffleWriter[K, V](
       writer.commitAndClose()
       writer.fileSegment().length
     }
+    // Rename all shuffle files to their final paths.
+    // Note: there is only one ShuffleBlockResolver per executor.
+    shuffleBlockResolver.synchronized {
+      shuffle.writers.zipWithIndex.foreach { case (writer, i) =>
+        val output = blockManager.diskBlockManager.getFile(writer.blockId)
+        if (sizes(i) > 0) {
+          if (output.exists()) {
+            // Use length of existing file and delete our own temporary one
+            sizes(i) = output.length()
+            writer.file.delete()
+          } else {
+            // Commit by renaming our temporary file to something the fetcher expects
+            if (!writer.file.renameTo(output)) {
+              throw new IOException(s"fail to rename ${writer.file} to $output")
+            }
+          }
+        } else {
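+          // Nothing was written for this partition; remove any stale output file left behind
+          // by a previous attempt.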
+          if (output.exists()) {
+            output.delete()
+          }
+        }
+      }
+    }
     MapStatus(blockManager.shuffleServerId, sizes)
   }
 
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 808317b017a0f04420ec3b827a56f7e5f87d31d1..f83cf8859e581e5c842457fa1cb569f7fcade58a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -20,8 +20,9 @@ package org.apache.spark.shuffle.sort
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
 import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class SortShuffleWriter[K, V, C](
@@ -65,11 +66,11 @@ private[spark] class SortShuffleWriter[K, V, C](
     // Don't bother including the time to open the merged output file in the shuffle write time,
     // because it just opens a single file, so is typically too fast to measure accurately
     // (see SPARK-3570).
-    val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
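+    // Write to a temporary file and let the resolver commit it atomically with the index file.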
+    val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+    val tmp = Utils.tempFileWith(output)
     val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
-    val partitionLengths = sorter.writePartitionedFile(blockId, outputFile)
-    shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
-
+    val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
+    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
   }
 
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index c374b93766225c1b0e902c02b02832dc5945d421..661c706af32b1a9dd7c7e29ba9b8dd7e70b9658f 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -21,10 +21,10 @@ import java.io._
 import java.nio.{ByteBuffer, MappedByteBuffer}
 
 import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.concurrent.{ExecutionContext, Await, Future}
 import scala.concurrent.duration._
-import scala.util.control.NonFatal
+import scala.concurrent.{Await, ExecutionContext, Future}
 import scala.util.Random
+import scala.util.control.NonFatal
 
 import sun.nio.ch.DirectBuffer
 
@@ -38,9 +38,8 @@ import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.ExternalShuffleClient
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
 import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
+import org.apache.spark.serializer.{Serializer, SerializerInstance}
 import org.apache.spark.shuffle.ShuffleManager
-import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.util._
 
 private[spark] sealed trait BlockValues
@@ -660,7 +659,7 @@ private[spark] class BlockManager(
     val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
     new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream,
-      syncWrites, writeMetrics)
+      syncWrites, writeMetrics, blockId)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 80d426fadc65e68990f282e54e0e57fba4423f88..e2dd80f24393035be517bdd45f423524fd17ca7c 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -34,14 +34,15 @@ import org.apache.spark.util.Utils
  * reopened again.
  */
 private[spark] class DiskBlockObjectWriter(
-    file: File,
+    val file: File,
     serializerInstance: SerializerInstance,
     bufferSize: Int,
     compressStream: OutputStream => OutputStream,
     syncWrites: Boolean,
     // These write metrics concurrently shared with other active DiskBlockObjectWriters who
     // are themselves performing writes. All updates must be relative.
-    writeMetrics: ShuffleWriteMetrics)
+    writeMetrics: ShuffleWriteMetrics,
+    val blockId: BlockId = null)
   extends OutputStream
   with Logging {
 
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 316c194ff3454552ec0f7474bed6e1742adb79b7..1b3acb8ef7f51df2f0fb8fecd59f48cdc0081013 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -21,8 +21,8 @@ import java.io._
 import java.lang.management.ManagementFactory
 import java.net._
 import java.nio.ByteBuffer
-import java.util.{Properties, Locale, Random, UUID}
 import java.util.concurrent._
+import java.util.{Locale, Properties, Random, UUID}
 import javax.net.ssl.HttpsURLConnection
 
 import scala.collection.JavaConverters._
@@ -30,7 +30,7 @@ import scala.collection.Map
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 import scala.reflect.ClassTag
-import scala.util.{Failure, Success, Try}
+import scala.util.Try
 import scala.util.control.{ControlThrowable, NonFatal}
 
 import com.google.common.io.{ByteStreams, Files}
@@ -42,7 +42,6 @@ import org.apache.hadoop.security.UserGroupInformation
 import org.apache.log4j.PropertyConfigurator
 import org.eclipse.jetty.util.MultiException
 import org.json4s._
-
 import tachyon.TachyonURI
 import tachyon.client.{TachyonFS, TachyonFile}
 
@@ -2169,6 +2168,13 @@ private[spark] object Utils extends Logging {
     val resource = createResource
     try f.apply(resource) finally resource.close()
   }
+
+  /**
+   * Returns the path of a temporary file in the same directory as `path`.
+   */
+  def tempFileWith(path: File): File = {
+    new File(path.getAbsolutePath + "." + UUID.randomUUID())
+  }
 }
 
 /**
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index bd6844d045cad50baaaee78e501caa7102f0bf2c..2440139ac95e995d10ea62c34e6a5da2a621d2de 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -638,7 +638,6 @@ private[spark] class ExternalSorter[K, V, C](
    * called by the SortShuffleWriter.
    *
    * @param blockId block ID to write to. The index file will be blockId.name + ".index".
-   * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
    * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
    */
   def writePartitionedFile(
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0b19861fc41eedd5492eaf558735b194432bb75a
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.{File, FileInputStream, FileOutputStream}
+
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.mockito.{Mock, MockitoAnnotations}
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.storage._
+import org.apache.spark.util.Utils
+import org.apache.spark.{SparkConf, SparkFunSuite}
+
+
+class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+  @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
+
+  private var tempDir: File = _
+  private val conf: SparkConf = new SparkConf(loadDefaults = false)
+
+  override def beforeEach(): Unit = {
+    tempDir = Utils.createTempDir()
+    MockitoAnnotations.initMocks(this)
+
+    when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
+    when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
+      new Answer[File] {
+        override def answer(invocation: InvocationOnMock): File = {
+          new File(tempDir, invocation.getArguments.head.toString)
+        }
+      })
+  }
+
+  override def afterEach(): Unit = {
+    Utils.deleteRecursively(tempDir)
+  }
+
+  test("commit shuffle files multiple times") {
+    val lengths = Array[Long](10, 0, 20)
+    val resolver = new IndexShuffleBlockResolver(conf, blockManager)
+    val dataTmp = File.createTempFile("shuffle", null, tempDir)
+    val out = new FileOutputStream(dataTmp)
+    out.write(new Array[Byte](30))
+    out.close()
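+    // First commit: the temporary data file should be renamed to the final data file.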
+    resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp)
+
+    val dataFile = resolver.getDataFile(1, 2)
+    assert(dataFile.exists())
+    assert(dataFile.length() === 30)
+    assert(!dataTmp.exists())
+
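+    // Second commit for the same map output: the existing files are kept, the new temporary
+    // file is discarded, and lengths2 is overwritten with the originally committed lengths.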
+    val dataTmp2 = File.createTempFile("shuffle", null, tempDir)
+    val out2 = new FileOutputStream(dataTmp2)
+    val lengths2 = new Array[Long](3)
+    out2.write(Array[Byte](1))
+    out2.write(new Array[Byte](29))
+    out2.close()
+    resolver.writeIndexFileAndCommit(1, 2, lengths2, dataTmp2)
+    assert(lengths2.toSeq === lengths.toSeq)
+    assert(dataFile.exists())
+    assert(dataFile.length() === 30)
+    assert(!dataTmp2.exists())
+
+    // The dataFile should be the previous one
+    val in = new FileInputStream(dataFile)
+    val firstByte = new Array[Byte](1)
+    in.read(firstByte)
+    assert(firstByte(0) === 0)
+
+    // Remove the data file so that the next commit is treated as the first successful attempt.
+    dataFile.delete()
+
+    val dataTmp3 = File.createTempFile("shuffle", null, tempDir)
+    val out3 = new FileOutputStream(dataTmp3)
+    val lengths3 = Array[Long](10, 10, 15)
+    out3.write(Array[Byte](2))
+    out3.write(new Array[Byte](34))
+    out3.close()
+    resolver.writeIndexFileAndCommit(1, 2, lengths3, dataTmp3)
+    assert(lengths3.toSeq != lengths.toSeq)
+    assert(dataFile.exists())
+    assert(dataFile.length() === 35)
+    assert(!dataTmp3.exists())
+
+    // The dataFile should now contain the newly committed data
+    val in2 = new FileInputStream(dataFile)
+    val firstByte2 = new Array[Byte](1)
+    in2.read(firstByte2)
+    assert(firstByte2(0) === 2)
+  }
+}
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 0e0eca515afc1f5a381e0c6353739392f7df9847..bc85918c59aabfbb3edcb94fc93294628497a5d3 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -130,7 +130,8 @@ public class UnsafeShuffleWriterSuite {
           (Integer) args[3],
           new CompressStream(),
           false,
-          (ShuffleWriteMetrics) args[4]
+          (ShuffleWriteMetrics) args[4],
+          (BlockId) args[0]
         );
       }
     });
@@ -169,9 +170,13 @@ public class UnsafeShuffleWriterSuite {
       @Override
       public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
         partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
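+        // Mimic the real resolver's commit: move the temporary data file into the final
+        // merged output location.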
+        File tmp = (File) invocationOnMock.getArguments()[3];
+        mergedOutputFile.delete();
+        tmp.renameTo(mergedOutputFile);
         return null;
       }
-    }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+    }).when(shuffleBlockResolver)
+      .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class));
 
     when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
       new Answer<Tuple2<TempShuffleBlockId, File>>() {
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 3bca790f3087072c00b43d9afbcb9bb5be9f0678..d87a1d2a56d99f0db60bbd7e04bfce28a3f4b118 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -117,7 +117,8 @@ public abstract class AbstractBytesToBytesMapSuite {
           (Integer) args[3],
           new CompressStream(),
           false,
-          (ShuffleWriteMetrics) args[4]
+          (ShuffleWriteMetrics) args[4],
+          (BlockId) args[0]
         );
       }
     });
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 11c3a7be388759048b4b5f795077dfe762f24d5c..a1c9f6fab8e6533b4ae50c89ae45c999484449c2 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -130,7 +130,8 @@ public class UnsafeExternalSorterSuite {
           (Integer) args[3],
           new CompressStream(),
           false,
-          (ShuffleWriteMetrics) args[4]
+          (ShuffleWriteMetrics) args[4],
+          (BlockId) args[0]
         );
       }
     });
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 4a0877d86f2c668788ea9b55c74db1247621f570..0de10ae485378ab241a12f8bc3d65c9d52b59354 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -17,12 +17,16 @@
 
 package org.apache.spark
 
+import java.util.concurrent.{Callable, CyclicBarrier, ExecutorService, Executors}
+
 import org.scalatest.Matchers
 
 import org.apache.spark.ShuffleSuite.NonJavaSerializableClass
+import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD}
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.shuffle.ShuffleWriter
 import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId}
 import org.apache.spark.util.MutablePair
 
@@ -317,6 +321,107 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     assert(metrics.bytesWritten === metrics.byresRead)
     assert(metrics.bytesWritten > 0)
   }
+
+  test("multiple simultaneous attempts for one task (SPARK-8029)") {
+    sc = new SparkContext("local", "test", conf)
+    val mapTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+    val manager = sc.env.shuffleManager
+
+    val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0L)
+    val metricsSystem = sc.env.metricsSystem
+    val shuffleMapRdd = new MyRDD(sc, 1, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
+    val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep)
+
+    // first attempt -- it's successful
+    val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
+      new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem,
+        InternalAccumulator.create(sc)))
+    val data1 = (1 to 10).map { x => x -> x}
+
+    // second attempt -- also successful.  We'll write out different data,
+    // just to simulate the fact that the records may get written differently
+    // depending on what gets spilled, what gets combined, etc.
+    val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
+      new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem,
+        InternalAccumulator.create(sc)))
+    val data2 = (11 to 20).map { x => x -> x}
+
+    // interleave writes of both attempts -- we want to test that both attempts can occur
+    // simultaneously, and everything is still OK
+
+    def writeAndClose(
+      writer: ShuffleWriter[Int, Int])(
+      iter: Iterator[(Int, Int)]): Option[MapStatus] = {
+      writer.write(iter)
+      writer.stop(true)
+    }
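+    // Run both attempts on separate threads, forcing them to alternate record by record.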
+    val interleaver = new InterleaveIterators(
+      data1, writeAndClose(writer1), data2, writeAndClose(writer2))
+    val (mapOutput1, mapOutput2) = interleaver.run()
+
+    // check that we can read the map output and it has the right data
+    assert(mapOutput1.isDefined)
+    assert(mapOutput2.isDefined)
+    assert(mapOutput1.get.location === mapOutput2.get.location)
+    assert(mapOutput1.get.getSizeForBlock(0) === mapOutput2.get.getSizeForBlock(0))
+
+    // register one of the map outputs -- doesn't matter which one
+    mapOutput1.foreach { case mapStatus =>
+      mapTrackerMaster.registerMapOutputs(0, Array(mapStatus))
+    }
+
+    val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
+      new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem,
+        InternalAccumulator.create(sc)))
+    val readData = reader.read().toIndexedSeq
+    assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
+
+    manager.unregisterShuffle(0)
+  }
+}
+
+/**
+ * Utility to help tests process two different iterators simultaneously in different threads.
+ * It makes sure that a test does not completely process data1 with f1 before processing data2
+ * with f2 (or vice versa): a barrier forces each function to process a single element and then
+ * wait for the other function to "catch up".
+ */
+class InterleaveIterators[T, R](
+  data1: Seq[T],
+  f1: Iterator[T] => R,
+  data2: Seq[T],
+  f2: Iterator[T] => R) {
+
+  require(data1.size == data2.size)
+
+  val barrier = new CyclicBarrier(2)
+  class BarrierIterator[E](id: Int, sub: Iterator[E]) extends Iterator[E] {
+    def hasNext: Boolean = sub.hasNext
+
+    def next: E = {
+      barrier.await()
+      sub.next()
+    }
+  }
+
+  val c1 = new Callable[R] {
+    override def call(): R = f1(new BarrierIterator(1, data1.iterator))
+  }
+  val c2 = new Callable[R] {
+    override def call(): R = f2(new BarrierIterator(2, data2.iterator))
+  }
+
+  val e: ExecutorService = Executors.newFixedThreadPool(2)
+
+  def run(): (R, R) = {
+    val future1 = e.submit(c1)
+    val future2 = e.submit(c2)
+    val r1 = future1.get()
+    val r2 = future2.get()
+    e.shutdown()
+    (r1, r2)
+  }
 }
 
 object ShuffleSuite {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index b92a302806f764d10306f5c55c1f6c9d0506c869..d3b1b2b620b4d38183332723dc802629a2455872 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -68,6 +68,17 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
     when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf)))
     when(taskContext.taskMetrics()).thenReturn(taskMetrics)
     when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
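+    // Stub writeIndexFileAndCommit to behave like the real resolver: rename the temporary
+    // data file produced by the writer to the final output file.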
+    doAnswer(new Answer[Void] {
+      def answer(invocationOnMock: InvocationOnMock): Void = {
+        val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
+        if (tmp != null) {
+          outputFile.delete
+          tmp.renameTo(outputFile)
+        }
+        null
+      }
+    }).when(blockResolver)
+      .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))
     when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
     when(blockManager.getDiskWriter(
       any[BlockId],
@@ -84,7 +95,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
           args(3).asInstanceOf[Int],
           compressStream = identity,
           syncWrites = false,
-          args(4).asInstanceOf[ShuffleWriteMetrics]
+          args(4).asInstanceOf[ShuffleWriteMetrics],
+          blockId = args(0).asInstanceOf[BlockId]
         )
       }
     })