diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala
index 19d9bebfe5ec08812fbfe33eba2c259294c5ad20..10143d3dd22bb24ab2a05044facd84f5636cea70 100644
--- a/core/src/main/scala/spark/BoundedMemoryCache.scala
+++ b/core/src/main/scala/spark/BoundedMemoryCache.scala
@@ -60,10 +60,14 @@ class BoundedMemoryCache extends Cache with Logging {
     val iter = map.entrySet.iterator
     while (maxBytes - currentBytes < space && iter.hasNext) {
       val mapEntry = iter.next()
-      logInfo("Dropping key %s of size %d to make space".format(
-        mapEntry.getKey, mapEntry.getValue.size))
+      dropEntry(mapEntry.getKey, mapEntry.getValue)
       currentBytes -= mapEntry.getValue.size
       iter.remove()
     }
   }
+
+  protected def dropEntry(key: Any, entry: Entry) {
+    logInfo("Dropping key %s of size %d to make space".format(
+      key, entry.size))
+  }
 }
diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala
new file mode 100644
index 0000000000000000000000000000000000000000..9e52fee69e44858c5fbdf869d068d0319d39f46b
--- /dev/null
+++ b/core/src/main/scala/spark/DiskSpillingCache.scala
@@ -0,0 +1,76 @@
+package spark
+
+import java.io.File
+import java.io.{FileOutputStream,FileInputStream}
+import java.io.IOException
+import java.util.LinkedHashMap
+import java.util.UUID
+
+// TODO: cache into a separate directory using Utils.createTempDir
+// TODO: clean up disk cache afterwards
+
+class DiskSpillingCache extends BoundedMemoryCache {
+  private val diskMap = new LinkedHashMap[Any, File](32, 0.75f, true)
+
+  override def get(key: Any): Any = {
+    synchronized {
+      val ser = Serializer.newInstance()
+      super.get(key) match {
+        case bytes: Any => // found in memory
+          ser.deserialize(bytes.asInstanceOf[Array[Byte]])
+
+        case _ => diskMap.get(key) match {
+          case file: Any => // found on disk
+            try {
+              val startTime = System.currentTimeMillis
+              val bytes = new Array[Byte](file.length.toInt)
+              new FileInputStream(file).read(bytes)
+              val timeTaken = System.currentTimeMillis - startTime
+              logInfo("Reading key %s of size %d bytes from disk took %d ms".format(
+                key, file.length, timeTaken))
+              super.put(key, bytes)
+              ser.deserialize(bytes.asInstanceOf[Array[Byte]])
+            } catch {
+              case e: IOException =>
+                logWarning("Failed to read key %s from disk at %s: %s".format(
+                  key, file.getPath(), e.getMessage()))
+                diskMap.remove(key) // remove dead entry
+                null
+            }
+
+          case _ => // not found
+            null
+        }
+      }
+    }
+  }
+
+  override def put(key: Any, value: Any) {
+    var ser = Serializer.newInstance()
+    super.put(key, ser.serialize(value))
+  }
+
+  /**
+   * Spill the given entry to disk. Assumes that a lock is held on the
+   * DiskSpillingCache. Assumes that entry.value is a byte array.
+   */
+  override protected def dropEntry(key: Any, entry: Entry) {
+    logInfo("Spilling key %s of size %d to make space".format(
+      key, entry.size))
+    val cacheDir = System.getProperty(
+      "spark.diskSpillingCache.cacheDir",
+      System.getProperty("java.io.tmpdir"))
+    val file = new File(cacheDir, "spark-dsc-" + UUID.randomUUID.toString)
+    try {
+      val stream = new FileOutputStream(file)
+      stream.write(entry.value.asInstanceOf[Array[Byte]])
+      stream.close()
+      diskMap.put(key, file)
+    } catch {
+      case e: IOException =>
+        logWarning("Failed to spill key %s to disk at %s: %s".format(
+          key, file.getPath(), e.getMessage()))
+        // Do nothing and let the entry be discarded
+    }
+  }
+}
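
For context, the new cache serializes values on put, spills evicted entries to a file named "spark-dsc-<uuid>" under spark.diskSpillingCache.cacheDir (falling back to java.io.tmpdir), and reads them back on a later get. The standalone sketch below illustrates that spill/read round trip outside of Spark; the object name SpillSketch, its helper methods, and the use of plain Java serialization in place of spark.Serializer are illustrative assumptions and not part of the patch.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileInputStream, FileOutputStream}
import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
import java.util.UUID

object SpillSketch {
  // Same lookup order as dropEntry: an explicit cache dir, else the JVM temp dir.
  def cacheDir: String = System.getProperty(
    "spark.diskSpillingCache.cacheDir",
    System.getProperty("java.io.tmpdir"))

  // Stand-in for spark.Serializer (not shown in this diff): plain Java serialization.
  def serialize(value: AnyRef): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    try oos.writeObject(value) finally oos.close()
    bos.toByteArray
  }

  def deserialize(bytes: Array[Byte]): AnyRef = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
    try ois.readObject() finally ois.close()
  }

  // What dropEntry does on eviction: write the serialized entry to a uniquely named file.
  def spill(bytes: Array[Byte]): File = {
    val file = new File(cacheDir, "spark-dsc-" + UUID.randomUUID.toString)
    val stream = new FileOutputStream(file)
    try stream.write(bytes) finally stream.close()
    file
  }

  // What get does on a disk hit: read the file back into a byte array.
  // Unlike the patch, this loops until the buffer is full instead of relying on one read().
  def unspill(file: File): Array[Byte] = {
    val bytes = new Array[Byte](file.length.toInt)
    val stream = new FileInputStream(file)
    try {
      var offset = 0
      while (offset < bytes.length) {
        val n = stream.read(bytes, offset, bytes.length - offset)
        if (n < 0) throw new IOException("Unexpected end of file: " + file)
        offset += n
      }
    } finally stream.close()
    bytes
  }

  def main(args: Array[String]): Unit = {
    val original = Vector(1, 2, 3)
    val file = spill(serialize(original))
    val restored = deserialize(unspill(file))
    println("Restored %s from %s".format(restored, file.getPath))
    file.delete()
  }
}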