diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 0bce531aaba3e68cf6c7c7984d92d72d01928de2..dd8e4ac66dc666c8f5e46333fa98747ebccc17be 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -35,7 +35,7 @@ import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.network.ConnectionManager
 import org.apache.spark.scheduler.LiveListenerBus
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
 import org.apache.spark.storage._
 import org.apache.spark.util.{AkkaUtils, Utils}
 
@@ -66,12 +66,9 @@ class SparkEnv (
     val httpFileServer: HttpFileServer,
     val sparkFilesDir: String,
     val metricsSystem: MetricsSystem,
+    val shuffleMemoryManager: ShuffleMemoryManager,
     val conf: SparkConf) extends Logging {
 
-  // A mapping of thread ID to amount of memory, in bytes, used for shuffle aggregations
-  // All accesses should be manually synchronized
-  val shuffleMemoryMap = mutable.HashMap[Long, Long]()
-
   private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
 
   // A general, soft-reference map for metadata needed during HadoopRDD split computation
@@ -252,6 +249,8 @@ object SparkEnv extends Logging {
     val shuffleManager = instantiateClass[ShuffleManager](
       "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
 
+    val shuffleMemoryManager = new ShuffleMemoryManager(conf)
+
     // Warn about deprecated spark.cache.class property
     if (conf.contains("spark.cache.class")) {
       logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@@ -273,6 +272,7 @@ object SparkEnv extends Logging {
       httpFileServer,
       sparkFilesDir,
       metricsSystem,
+      shuffleMemoryManager,
       conf)
   }
 
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 1bb1b4aae91bb1155f998bef83dac19d7410c0a0..c2b9c660ddaecaf14833ee453bd22e0e12a13479 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -276,10 +276,7 @@ private[spark] class Executor(
         }
       } finally {
         // Release memory used by this thread for shuffles
-        val shuffleMemoryMap = env.shuffleMemoryMap
-        shuffleMemoryMap.synchronized {
-          shuffleMemoryMap.remove(Thread.currentThread().getId)
-        }
+        env.shuffleMemoryManager.releaseMemoryForThisThread()
         // Release memory used by this thread for unrolling blocks
         env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
         runningTasks.remove(taskId)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ee91a368b76eadab6e402ae5b33ea00a3a9aeb97
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import scala.collection.mutable
+
+import org.apache.spark.{Logging, SparkException, SparkConf}
+
+/**
+ * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
+ * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
+ * from this pool and release it as it spills data out. When a task ends, all its memory will be
+ * released by the Executor.
+ *
+ * This class tries to ensure that each thread gets a reasonable share of memory, instead of some
+ * thread ramping up to a large amount first and then causing others to spill to disk repeatedly.
+ * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory
+ * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
+ * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever
+ * this set changes. This is all done by synchronizing access on "this" to mutate state and using
+ * wait() and notifyAll() to signal changes.
+ */
+private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
+  private val threadMemory = new mutable.HashMap[Long, Long]()  // threadId -> memory bytes
+
+  def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))
+
+  /**
+   * Try to acquire up to numBytes memory for the current thread, and return the number of bytes
+   * obtained, or 0 if none can be allocated. This call may block until there is enough free memory
+   * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the
+   * total memory pool (where N is the # of active threads) before it is forced to spill. This can
+   * happen if the number of threads increases but an older thread had a lot of memory already.
+   */
+  def tryToAcquire(numBytes: Long): Long = synchronized {
+    val threadId = Thread.currentThread().getId
+    assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)
+
+    // Add this thread to the threadMemory map just so we can keep an accurate count of the number
+    // of active threads, to let other threads ramp down their memory in calls to tryToAcquire
+    if (!threadMemory.contains(threadId)) {
+      threadMemory(threadId) = 0L
+      notifyAll()  // Will later cause waiting threads to wake up and check numThreads again
+    }
+
+    // Keep looping until we're either sure that we don't want to grant this request (because this
+    // thread would have more than 1 / numActiveThreads of the memory) or we have enough free
+    // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)).
+    while (true) {
+      val numActiveThreads = threadMemory.keys.size
+      val curMem = threadMemory(threadId)
+      val freeMemory = maxMemory - threadMemory.values.sum
+
+      // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads
+      val maxToGrant = math.min(numBytes, (maxMemory / numActiveThreads) - curMem)
+
+      if (curMem < maxMemory / (2 * numActiveThreads)) {
+        // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking;
+        // if we can't give it this much now, wait for other threads to free up memory
+        // (this happens if older threads allocated lots of memory before N grew)
+        if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) {
+          val toGrant = math.min(maxToGrant, freeMemory)
+          threadMemory(threadId) += toGrant
+          return toGrant
+        } else {
+          logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free")
+          wait()
+        }
+      } else {
+        // Only give it as much memory as is free, which might be none if it reached 1 / numThreads
+        val toGrant = math.min(maxToGrant, freeMemory)
+        threadMemory(threadId) += toGrant
+        return toGrant
+      }
+    }
+    0L  // Never reached
+  }
+
+  /** Release numBytes bytes for the current thread. */
+  def release(numBytes: Long): Unit = synchronized {
+    val threadId = Thread.currentThread().getId
+    val curMem = threadMemory.getOrElse(threadId, 0L)
+    if (curMem < numBytes) {
+      throw new SparkException(
+        s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}")
+    }
+    threadMemory(threadId) -= numBytes
+    notifyAll()  // Notify waiters who locked "this" in tryToAcquire that memory has been freed
+  }
+
+  /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */
+  def releaseMemoryForThisThread(): Unit = synchronized {
+    val threadId = Thread.currentThread().getId
+    threadMemory.remove(threadId)
+    notifyAll()  // Notify waiters who locked "this" in tryToAcquire that memory has been freed
+  }
+}
+
+private object ShuffleMemoryManager {
+  /**
+   * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction
+   * of the memory pool and a safety factor since collections can sometimes grow bigger than
+   * the size we target before we estimate their sizes again.
+   */
+  def getMaxMemory(conf: SparkConf): Long = {
+    val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
+    val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
+    (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 1f7d2dc838ebce62173f6b80ab8f55f2735018fb..cc0423856cefb2602db356644c8cb1b05f9a32b2 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -71,13 +71,7 @@ class ExternalAppendOnlyMap[K, V, C](
   private val spilledMaps = new ArrayBuffer[DiskMapIterator]
   private val sparkConf = SparkEnv.get.conf
   private val diskBlockManager = blockManager.diskBlockManager
-
-  // Collective memory threshold shared across all running tasks
-  private val maxMemoryThreshold = {
-    val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.2)
-    val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8)
-    (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
-  }
+  private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
 
   // Number of pairs inserted since last spill; note that we count them even if a value is merged
   // with a previous key in case we're doing something like groupBy where the result grows
@@ -140,28 +134,15 @@ class ExternalAppendOnlyMap[K, V, C](
       if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
           currentMap.estimateSize() >= myMemoryThreshold)
       {
-        val currentSize = currentMap.estimateSize()
-        var shouldSpill = false
-        val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
-
-        // Atomically check whether there is sufficient memory in the global pool for
-        // this map to grow and, if possible, allocate the required amount
-        shuffleMemoryMap.synchronized {
-          val threadId = Thread.currentThread().getId
-          val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId)
-          val availableMemory = maxMemoryThreshold -
-            (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L))
-
-          // Try to allocate at least 2x more memory, otherwise spill
-          shouldSpill = availableMemory < currentSize * 2
-          if (!shouldSpill) {
-            shuffleMemoryMap(threadId) = currentSize * 2
-            myMemoryThreshold = currentSize * 2
-          }
-        }
-        // Do not synchronize spills
-        if (shouldSpill) {
-          spill(currentSize)
+        // Claim up to double our current memory from the shuffle memory pool
+        val currentMemory = currentMap.estimateSize()
+        val amountToRequest = 2 * currentMemory - myMemoryThreshold
+        val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
+        myMemoryThreshold += granted
+        if (myMemoryThreshold <= currentMemory) {
+          // We were granted too little memory to grow further (either tryToAcquire returned 0,
+          // or we already had more memory than myMemoryThreshold); spill the current collection
+          spill(currentMemory)  // Will also release memory back to ShuffleMemoryManager
         }
       }
       currentMap.changeValue(curEntry._1, update)
@@ -245,12 +226,9 @@ class ExternalAppendOnlyMap[K, V, C](
     currentMap = new SizeTrackingAppendOnlyMap[K, C]
     spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
 
-    // Reset the amount of shuffle memory used by this map in the global pool
-    val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
-    shuffleMemoryMap.synchronized {
-      shuffleMemoryMap(Thread.currentThread().getId) = 0
-    }
-    myMemoryThreshold = 0
+    // Release our memory back to the shuffle pool so that other threads can grab it
+    shuffleMemoryManager.release(myMemoryThreshold)
+    myMemoryThreshold = 0L
 
     elementsRead = 0
     _memoryBytesSpilled += mapSize
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index b04c50bd3e196b990d837bdfdffd6caa629826e4..101c83b264f6345ba63300973660e1ea74bdc841 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -78,6 +78,7 @@ private[spark] class ExternalSorter[K, V, C](
 
   private val blockManager = SparkEnv.get.blockManager
   private val diskBlockManager = blockManager.diskBlockManager
+  private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
   private val ser = Serializer.getSerializer(serializer)
   private val serInstance = ser.newInstance()
 
@@ -116,13 +117,6 @@ private[spark] class ExternalSorter[K, V, C](
   private var _memoryBytesSpilled = 0L
   private var _diskBytesSpilled = 0L
 
-  // Collective memory threshold shared across all running tasks
-  private val maxMemoryThreshold = {
-    val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
-    val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
-    (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
-  }
-
   // How much of the shared memory pool this collection has claimed
   private var myMemoryThreshold = 0L
 
@@ -218,31 +212,15 @@ private[spark] class ExternalSorter[K, V, C](
     if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
         collection.estimateSize() >= myMemoryThreshold)
     {
-      // TODO: This logic doesn't work if there are two external collections being used in the same
-      // task (e.g. to read shuffle output and write it out into another shuffle) [SPARK-2711]
-
-      val currentSize = collection.estimateSize()
-      var shouldSpill = false
-      val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
-
-      // Atomically check whether there is sufficient memory in the global pool for
-      // us to double our threshold
-      shuffleMemoryMap.synchronized {
-        val threadId = Thread.currentThread().getId
-        val previouslyClaimedMemory = shuffleMemoryMap.get(threadId)
-        val availableMemory = maxMemoryThreshold -
-          (shuffleMemoryMap.values.sum - previouslyClaimedMemory.getOrElse(0L))
-
-        // Try to allocate at least 2x more memory, otherwise spill
-        shouldSpill = availableMemory < currentSize * 2
-        if (!shouldSpill) {
-          shuffleMemoryMap(threadId) = currentSize * 2
-          myMemoryThreshold = currentSize * 2
-        }
-      }
-      // Do not hold lock during spills
-      if (shouldSpill) {
-        spill(currentSize, usingMap)
+      // Claim up to double our current memory from the shuffle memory pool
+      val currentMemory = collection.estimateSize()
+      val amountToRequest = 2 * currentMemory - myMemoryThreshold
+      val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
+      myMemoryThreshold += granted
+      if (myMemoryThreshold <= currentMemory) {
+        // We were granted too little memory to grow further (either tryToAcquire returned 0,
+        // or we already had more memory than myMemoryThreshold); spill the current collection
+        spill(currentMemory, usingMap)  // Will also release memory back to ShuffleMemoryManager
       }
     }
   }
@@ -327,11 +305,8 @@ private[spark] class ExternalSorter[K, V, C](
       buffer = new SizeTrackingPairBuffer[(Int, K), C]
     }
 
-    // Reset the amount of shuffle memory used by this map in the global pool
-    val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
-    shuffleMemoryMap.synchronized {
-      shuffleMemoryMap(Thread.currentThread().getId) = 0
-    }
+    // Release our memory back to the shuffle pool so that other threads can grab it
+    shuffleMemoryManager.release(myMemoryThreshold)
     myMemoryThreshold = 0
 
     spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..d31bc22ee74f755efad13bcbee19e565d993e6d2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
@@ -0,0 +1,294 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.scalatest.FunSuite
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.time.SpanSugar._
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.CountDownLatch
+
+class ShuffleMemoryManagerSuite extends FunSuite with Timeouts {
+  /** Launch a thread with the given body block and return it. */
+  private def startThread(name: String)(body: => Unit): Thread = {
+    val thread = new Thread("ShuffleMemorySuite " + name) {
+      override def run() {
+        body
+      }
+    }
+    thread.start()
+    thread
+  }
+
+  test("single thread requesting memory") {
+    val manager = new ShuffleMemoryManager(1000L)
+
+    assert(manager.tryToAcquire(100L) === 100L)
+    assert(manager.tryToAcquire(400L) === 400L)
+    assert(manager.tryToAcquire(400L) === 400L)
+    assert(manager.tryToAcquire(200L) === 100L)
+    assert(manager.tryToAcquire(100L) === 0L)
+    assert(manager.tryToAcquire(100L) === 0L)
+
+    manager.release(500L)
+    assert(manager.tryToAcquire(300L) === 300L)
+    assert(manager.tryToAcquire(300L) === 200L)
+
+    manager.releaseMemoryForThisThread()
+    assert(manager.tryToAcquire(1000L) === 1000L)
+    assert(manager.tryToAcquire(100L) === 0L)
+  }
+
+  test("two threads requesting full memory") {
+    // Two threads request 500 bytes first, wait for each other to get it, and then request
+    // 500 more; we should immediately return 0 as both are now at 1 / N
+
+    val manager = new ShuffleMemoryManager(1000L)
+
+    class State {
+      var t1Result1 = -1L
+      var t2Result1 = -1L
+      var t1Result2 = -1L
+      var t2Result2 = -1L
+    }
+    val state = new State
+
+    val t1 = startThread("t1") {
+      val r1 = manager.tryToAcquire(500L)
+      state.synchronized {
+        state.t1Result1 = r1
+        state.notifyAll()
+        while (state.t2Result1 === -1L) {
+          state.wait()
+        }
+      }
+      val r2 = manager.tryToAcquire(500L)
+      state.synchronized { state.t1Result2 = r2 }
+    }
+
+    val t2 = startThread("t2") {
+      val r1 = manager.tryToAcquire(500L)
+      state.synchronized {
+        state.t2Result1 = r1
+        state.notifyAll()
+        while (state.t1Result1 === -1L) {
+          state.wait()
+        }
+      }
+      val r2 = manager.tryToAcquire(500L)
+      state.synchronized { state.t2Result2 = r2 }
+    }
+
+    failAfter(20 seconds) {
+      t1.join()
+      t2.join()
+    }
+
+    assert(state.t1Result1 === 500L)
+    assert(state.t2Result1 === 500L)
+    assert(state.t1Result2 === 0L)
+    assert(state.t2Result2 === 0L)
+  }
+
+
+  test("threads cannot grow past 1 / N") {
+    // Two threads request 250 bytes first, wait for each other to get it, and then request
+    // 500 more; we should only grant 250 bytes to each of them on this second request
+
+    val manager = new ShuffleMemoryManager(1000L)
+
+    class State {
+      var t1Result1 = -1L
+      var t2Result1 = -1L
+      var t1Result2 = -1L
+      var t2Result2 = -1L
+    }
+    val state = new State
+
+    val t1 = startThread("t1") {
+      val r1 = manager.tryToAcquire(250L)
+      state.synchronized {
+        state.t1Result1 = r1
+        state.notifyAll()
+        while (state.t2Result1 === -1L) {
+          state.wait()
+        }
+      }
+      val r2 = manager.tryToAcquire(500L)
+      state.synchronized { state.t1Result2 = r2 }
+    }
+
+    val t2 = startThread("t2") {
+      val r1 = manager.tryToAcquire(250L)
+      state.synchronized {
+        state.t2Result1 = r1
+        state.notifyAll()
+        while (state.t1Result1 === -1L) {
+          state.wait()
+        }
+      }
+      val r2 = manager.tryToAcquire(500L)
+      state.synchronized { state.t2Result2 = r2 }
+    }
+
+    failAfter(20 seconds) {
+      t1.join()
+      t2.join()
+    }
+
+    assert(state.t1Result1 === 250L)
+    assert(state.t2Result1 === 250L)
+    assert(state.t1Result2 === 250L)
+    assert(state.t2Result2 === 250L)
+  }
+
+  test("threads can block to get at least 1 / 2N memory") {
+    // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
+    // for a bit and releases 250 bytes, which should then be greanted to t2. Further requests
+    // by t2 will return false right away because it now has 1 / 2N of the memory.
+
+    val manager = new ShuffleMemoryManager(1000L)
+
+    class State {
+      var t1Requested = false
+      var t2Requested = false
+      var t1Result = -1L
+      var t2Result = -1L
+      var t2Result2 = -1L
+      var t2WaitTime = 0L
+    }
+    val state = new State
+
+    val t1 = startThread("t1") {
+      state.synchronized {
+        state.t1Result = manager.tryToAcquire(1000L)
+        state.t1Requested = true
+        state.notifyAll()
+        while (!state.t2Requested) {
+          state.wait()
+        }
+      }
+      // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
+      // sure the other thread blocks for some time otherwise
+      Thread.sleep(300)
+      manager.release(250L)
+    }
+
+    val t2 = startThread("t2") {
+      state.synchronized {
+        while (!state.t1Requested) {
+          state.wait()
+        }
+        state.t2Requested = true
+        state.notifyAll()
+      }
+      val startTime = System.currentTimeMillis()
+      val result = manager.tryToAcquire(250L)
+      val endTime = System.currentTimeMillis()
+      state.synchronized {
+        state.t2Result = result
+        // A second call should return 0 because we're now already at 1 / 2N
+        state.t2Result2 = manager.tryToAcquire(100L)
+        state.t2WaitTime = endTime - startTime
+      }
+    }
+
+    failAfter(20 seconds) {
+      t1.join()
+      t2.join()
+    }
+
+    // Both threads should've been able to acquire their memory; the second one will have waited
+    // until the first one acquired 1000 bytes and then released 250
+    state.synchronized {
+      assert(state.t1Result === 1000L, "t1 could not allocate memory")
+      assert(state.t2Result === 250L, "t2 could not allocate memory")
+      assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})")
+      assert(state.t2Result2 === 0L, "t1 got extra memory the second time")
+    }
+  }
+
+  test("releaseMemoryForThisThread") {
+    // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
+    // for a bit and releases all its memory. t2 should now be able to grab all the memory.
+
+    val manager = new ShuffleMemoryManager(1000L)
+
+    class State {
+      var t1Requested = false
+      var t2Requested = false
+      var t1Result = -1L
+      var t2Result1 = -1L
+      var t2Result2 = -1L
+      var t2Result3 = -1L
+      var t2WaitTime = 0L
+    }
+    val state = new State
+
+    val t1 = startThread("t1") {
+      state.synchronized {
+        state.t1Result = manager.tryToAcquire(1000L)
+        state.t1Requested = true
+        state.notifyAll()
+        while (!state.t2Requested) {
+          state.wait()
+        }
+      }
+      // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
+      // sure the other thread blocks for some time otherwise
+      Thread.sleep(300)
+      manager.releaseMemoryForThisThread()
+    }
+
+    val t2 = startThread("t2") {
+      state.synchronized {
+        while (!state.t1Requested) {
+          state.wait()
+        }
+        state.t2Requested = true
+        state.notifyAll()
+      }
+      val startTime = System.currentTimeMillis()
+      val r1 = manager.tryToAcquire(500L)
+      val endTime = System.currentTimeMillis()
+      val r2 = manager.tryToAcquire(500L)
+      val r3 = manager.tryToAcquire(500L)
+      state.synchronized {
+        state.t2Result1 = r1
+        state.t2Result2 = r2
+        state.t2Result3 = r3
+        state.t2WaitTime = endTime - startTime
+      }
+    }
+
+    failAfter(20 seconds) {
+      t1.join()
+      t2.join()
+    }
+
+    // Both threads should've been able to acquire their memory; the second one will have waited
+    // until the first one acquired 1000 bytes and then released all of it
+    state.synchronized {
+      assert(state.t1Result === 1000L, "t1 could not allocate memory")
+      assert(state.t2Result1 === 500L, "t2 didn't get 500 bytes the first time")
+      assert(state.t2Result2 === 500L, "t2 didn't get 500 bytes the second time")
+      assert(state.t2Result3 === 0L, s"t2 got more bytes a third time (${state.t2Result3})")
+      assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})")
+    }
+  }
+}