diff --git a/core/src/main/scala/spark/storage/MemoryStore.scala b/core/src/main/scala/spark/storage/MemoryStore.scala index ae88ff0bb1694c89a580bdd89bf127991ecb7f1e..949588476c20150b1dd5c73f4303dbf85d2ad518 100644 --- a/core/src/main/scala/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/spark/storage/MemoryStore.scala @@ -32,8 +32,8 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + bytes.rewind() if (level.deserialized) { - bytes.rewind() val values = blockManager.dataDeserialize(blockId, bytes) val elements = new ArrayBuffer[Any] elements ++= values @@ -58,7 +58,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } else { val bytes = blockManager.dataSerialize(blockId, values.iterator) tryToPut(blockId, bytes, bytes.limit, false) - PutResult(bytes.limit(), Right(bytes)) + PutResult(bytes.limit(), Right(bytes.duplicate())) } } diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala index caa4ba3a3705af5587099951c1914a03663b5ea8..4104b33c8b6815ddebbe50ea595e633fb0cba46e 100644 --- a/core/src/test/scala/spark/DistributedSuite.scala +++ b/core/src/test/scala/spark/DistributedSuite.scala @@ -1,5 +1,6 @@ package spark +import network.ConnectionManagerId import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import org.scalatest.matchers.ShouldMatchers @@ -13,7 +14,7 @@ import com.google.common.io.Files import scala.collection.mutable.ArrayBuffer import SparkContext._ -import storage.StorageLevel +import storage.{GetBlock, BlockManagerWorker, StorageLevel} class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter with LocalSparkContext { @@ -140,9 +141,22 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter test("caching in memory and disk, serialized, replicated") { sc = new SparkContext(clusterUrl, "test") val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_SER_2) + assert(data.count() === 1000) assert(data.count() === 1000) assert(data.count() === 1000) + + // Get all the locations of the first partition and try to fetch the partitions + // from those locations. + val blockIds = data.partitions.indices.map(index => "rdd_%d_%d".format(data.id, index)).toArray + val blockId = blockIds(0) + val blockManager = SparkEnv.get.blockManager + blockManager.master.getLocations(blockId).foreach(id => { + val bytes = BlockManagerWorker.syncGetBlock( + GetBlock(blockId), ConnectionManagerId(id.ip, id.port)) + val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList + assert(deserialized === (1 to 100).toList) + }) } test("compute without caching when no partitions fit in memory") {