diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 5e45b375ddd452ea130f701afb47c8fc388e4549..2ec2f2031aa455abb085596ab701854590476abe 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -197,7 +197,8 @@ private[spark] object HighlyCompressedMapStatus {
     // block as being non-empty (or vice-versa) when using the average block size.
     var i = 0
     var numNonEmptyBlocks: Int = 0
-    var totalSize: Long = 0
+    var numSmallBlocks: Int = 0
+    var totalSmallBlockSize: Long = 0
     // From a compression standpoint, it shouldn't matter whether we track empty or non-empty
     // blocks. From a performance standpoint, we benefit from tracking empty blocks because
     // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions.
@@ -214,7 +215,8 @@ private[spark] object HighlyCompressedMapStatus {
         // Huge blocks are not included in the calculation for average size, thus size for smaller
         // blocks is more accurate.
         if (size < threshold) {
-          totalSize += size
+          totalSmallBlockSize += size
+          numSmallBlocks += 1
         } else {
           hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i)))
         }
@@ -223,8 +225,8 @@ private[spark] object HighlyCompressedMapStatus {
       }
       i += 1
     }
-    val avgSize = if (numNonEmptyBlocks > 0) {
-      totalSize / numNonEmptyBlocks
+    val avgSize = if (numSmallBlocks > 0) {
+      totalSmallBlockSize / numSmallBlocks
     } else {
       0
     }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
index 144e5afdcdd78609fda2f9b02a718ea48afff7d9..2155a0f2b6c214eac699e5a3ed64e9043052eba8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
@@ -98,6 +98,28 @@ class MapStatusSuite extends SparkFunSuite {
     }
   }
 
+  test("SPARK-22540: ensure HighlyCompressedMapStatus calculates correct avgSize") {
+    val threshold = 1000
+    val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, threshold.toString)
+    val env = mock(classOf[SparkEnv])
+    doReturn(conf).when(env).conf
+    SparkEnv.set(env)
+    val sizes = (0L to 3000L).toArray
+    val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold)
+    val avg = smallBlockSizes.sum / smallBlockSizes.length
+    val loc = BlockManagerId("a", "b", 10)
+    val status = MapStatus(loc, sizes)
+    val status1 = compressAndDecompressMapStatus(status)
+    assert(status1.isInstanceOf[HighlyCompressedMapStatus])
+    assert(status1.location == loc)
+    for (i <- 0 until threshold) {
+      val estimate = status1.getSizeForBlock(i)
+      if (sizes(i) > 0) {
+        assert(estimate === avg)
+      }
+    }
+  }
+
   def compressAndDecompressMapStatus(status: MapStatus): MapStatus = {
     val ser = new JavaSerializer(new SparkConf)
     val buf = ser.newInstance().serialize(status)
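
Why this change is needed: before this patch, the loop added only the sizes of blocks below the accurate-block threshold (SHUFFLE_ACCURATE_BLOCK_THRESHOLD) to totalSize, yet avgSize divided that sum by numNonEmptyBlocks, which also counts the huge blocks whose sizes never entered the sum, so avgSize came out systematically too low. A minimal standalone sketch of the arithmetic follows (not part of the patch; the object name and block sizes are made up for illustration):

// AvgSizeSketch.scala -- hypothetical example, not part of this patch.
object AvgSizeSketch {
  def main(args: Array[String]): Unit = {
    val threshold = 1000L
    // Three small blocks and two huge blocks (made-up sizes).
    val sizes = Array(100L, 200L, 300L, 50000L, 60000L)

    val nonEmpty = sizes.filter(_ > 0)
    val small = nonEmpty.filter(_ < threshold)

    // Old behaviour: huge blocks are excluded from the numerator but still
    // counted in the denominator, dragging the average down.
    val oldAvg = small.sum / nonEmpty.length // 600 / 5 = 120

    // New behaviour: numerator and denominator both cover small blocks only.
    val newAvg = small.sum / small.length    // 600 / 3 = 200

    println(s"old avgSize = $oldAvg, new avgSize = $newAvg")
  }
}

With the fix, getSizeForBlock reports 200 for the small blocks in this example instead of 120, which is what the new SPARK-22540 test asserts against real block sizes.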