Commit 67e085ef authored by Ryan Blue, committed by Reynold Xin

[SPARK-16420] Ensure compression streams are closed.

## What changes were proposed in this pull request?

This uses the try/finally pattern to ensure streams are closed after use. `UnsafeShuffleWriter` wasn't closing compression streams, causing them to leak resources until garbage collected. This was causing a problem with codecs that use off-heap memory.
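For reference, the close-on-all-paths idiom the patch applies looks roughly like the sketch below (the byte-array streams are placeholders for the real spill and merge streams, not code from this patch). Guava's `Closeables.close(closeable, swallowIOException)` is null-safe and, when asked to, logs and swallows the `IOException` from `close()` so it cannot mask the exception already in flight:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
import com.google.common.io.{ByteStreams, Closeables}

// Sketch of the close-on-all-paths idiom; streams here are stand-ins.
val out = new ByteArrayOutputStream()
var in: InputStream = null
var threwException = true
try {
  in = new ByteArrayInputStream(Array[Byte](1, 2, 3)) // may fail; `in` stays null
  ByteStreams.copy(in, out)
  threwException = false // reached only if the copy succeeded
} finally {
  // Null-safe close. If the body threw, any IOException raised by close()
  // is logged and swallowed so it cannot mask the original failure.
  Closeables.close(in, threwException)
}
```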

## How was this patch tested?

Current tests are sufficient. This should not change behavior.

Author: Ryan Blue <blue@apache.org>

Closes #14093 from rdblue/SPARK-16420-unsafe-shuffle-writer-leak.
parent 38cf8f2a
LimitedInputStream.java

@@ -48,11 +48,27 @@ import com.google.common.base.Preconditions;
  * use this functionality in both a Guava 11 environment and a Guava >14 environment.
  */
 public final class LimitedInputStream extends FilterInputStream {
+  private final boolean closeWrappedStream;
   private long left;
   private long mark = -1;

   public LimitedInputStream(InputStream in, long limit) {
+    this(in, limit, true);
+  }
+
+  /**
+   * Create a LimitedInputStream that will read {@code limit} bytes from {@code in}.
+   * <p>
+   * If {@code closeWrappedStream} is true, this will close {@code in} when it is closed.
+   * Otherwise, the stream is left open for reading its remaining content.
+   *
+   * @param in a {@link InputStream} to read from
+   * @param limit the number of bytes to read
+   * @param closeWrappedStream whether to close {@code in} when {@link #close} is called
+   */
+  public LimitedInputStream(InputStream in, long limit, boolean closeWrappedStream) {
     super(in);
+    this.closeWrappedStream = closeWrappedStream;
     Preconditions.checkNotNull(in);
     Preconditions.checkArgument(limit >= 0, "limit must be non-negative");
     left = limit;

@@ -102,4 +118,11 @@ public final class LimitedInputStream extends FilterInputStream {
     left -= skipped;
     return skipped;
   }
+
+  @Override
+  public void close() throws IOException {
+    if (closeWrappedStream) {
+      super.close();
+    }
+  }
 }
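A usage sketch of the new three-argument constructor (hypothetical data, not code from this patch): with `closeWrappedStream = false`, closing each limited view leaves the wrapped stream open, so consecutive ranges of one spill stream can be read through successive views:

```scala
import java.io.ByteArrayInputStream
import com.google.common.io.ByteStreams
import org.apache.spark.network.util.LimitedInputStream

// Hypothetical sketch: two 4-byte "partitions" stored back-to-back in one
// underlying stream, read through successive limited views.
val spillIn = new ByteArrayInputStream((0 until 8).map(_.toByte).toArray)

val first = new LimitedInputStream(spillIn, 4, false)
val firstBytes = ByteStreams.toByteArray(first)
first.close() // closes only the limited view; spillIn stays open

val second = new LimitedInputStream(spillIn, 4, false)
val secondBytes = ByteStreams.toByteArray(second)
second.close()
```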
UnsafeShuffleWriter.java

@@ -349,12 +349,19 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       for (int i = 0; i < spills.length; i++) {
         final long partitionLengthInSpill = spills[i].partitionLengths[partition];
         if (partitionLengthInSpill > 0) {
-          InputStream partitionInputStream =
-            new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
-          if (compressionCodec != null) {
-            partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+          InputStream partitionInputStream = null;
+          boolean innerThrewException = true;
+          try {
+            partitionInputStream =
+              new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false);
+            if (compressionCodec != null) {
+              partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+            }
+            ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
+            innerThrewException = false;
+          } finally {
+            Closeables.close(partitionInputStream, innerThrewException);
           }
-          ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
         }
       }
       mergedFileOutputStream.flush();
TorrentBroadcast.scala

@@ -232,7 +232,11 @@ private object TorrentBroadcast extends Logging {
     val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos)
     val ser = serializer.newInstance()
     val serOut = ser.serializeStream(out)
-    serOut.writeObject[T](obj).close()
+    Utils.tryWithSafeFinally {
+      serOut.writeObject[T](obj)
+    } {
+      serOut.close()
+    }
     cbbos.toChunkedByteBuffer.getChunks()
   }

@@ -246,8 +250,11 @@ private object TorrentBroadcast extends Logging {
     val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
     val ser = serializer.newInstance()
     val serIn = ser.deserializeStream(in)
-    val obj = serIn.readObject[T]()
-    serIn.close()
+    val obj = Utils.tryWithSafeFinally {
+      serIn.readObject[T]()
+    } {
+      serIn.close()
+    }
     obj
   }
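`Utils.tryWithSafeFinally` is the helper in `org.apache.spark.util.Utils` used above. A simplified sketch of its contract (the real implementation also logs the suppressed exception): the finally block always runs, and a failure inside it is attached as suppressed rather than replacing the exception already in flight:

```scala
// Simplified sketch of Utils.tryWithSafeFinally; the real helper also logs.
def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
  var original: Throwable = null
  try {
    block
  } catch {
    case t: Throwable =>
      original = t
      throw t
  } finally {
    try {
      finallyBlock
    } catch {
      case t: Throwable if original != null =>
        // Do not let a failure in close() mask the original exception.
        original.addSuppressed(t)
    }
  }
}
```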
GenericAvroSerializer.scala

@@ -32,6 +32,7 @@ import org.apache.commons.io.IOUtils
 import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
+import org.apache.spark.util.Utils

 /**
  * Custom serializer used for generic Avro records. If the user registers the schemas

@@ -72,8 +73,11 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
   def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, {
     val bos = new ByteArrayOutputStream()
     val out = codec.compressedOutputStream(bos)
-    out.write(schema.toString.getBytes(StandardCharsets.UTF_8))
-    out.close()
+    Utils.tryWithSafeFinally {
+      out.write(schema.toString.getBytes(StandardCharsets.UTF_8))
+    } {
+      out.close()
+    }
     bos.toByteArray
   })

@@ -86,7 +90,12 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
       schemaBytes.array(),
       schemaBytes.arrayOffset() + schemaBytes.position(),
       schemaBytes.remaining())
-    val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis))
+    val in = codec.compressedInputStream(bis)
+    val bytes = Utils.tryWithSafeFinally {
+      IOUtils.toByteArray(in)
+    } {
+      in.close()
+    }
     new Schema.Parser().parse(new String(bytes, StandardCharsets.UTF_8))
   })