diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index f6c9b87778f8fcbaf4e9305c4ccaf8cdb1d8f064..d80bd57bd2048f688b1a008f8fbf9a65bd824ab3 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -20,6 +20,7 @@ package org.apache.spark.unsafe.types; import javax.annotation.Nonnull; import java.io.Serializable; import java.io.UnsupportedEncodingException; +import java.nio.ByteOrder; import java.util.Arrays; import org.apache.spark.unsafe.PlatformDependent; @@ -53,6 +54,8 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { 5, 5, 5, 5, 6, 6}; + private static ByteOrder byteOrder = ByteOrder.nativeOrder(); + public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); /** @@ -175,18 +178,35 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable { // If size is greater than 4, assume we have at least 8 bytes of data to fetch. // After getting the data, we use a mask to mask out data that is not part of the string. long p; - if (numBytes >= 8) { - p = PlatformDependent.UNSAFE.getLong(base, offset); - } else if (numBytes > 4) { - p = PlatformDependent.UNSAFE.getLong(base, offset); - p = p & ((1L << numBytes * 8) - 1); - } else if (numBytes > 0) { - p = (long) PlatformDependent.UNSAFE.getInt(base, offset); - p = p & ((1L << numBytes * 8) - 1); + long mask = 0; + if (byteOrder == ByteOrder.LITTLE_ENDIAN) { + if (numBytes >= 8) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + } else if (numBytes > 4) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else if (numBytes > 0) { + p = (long) PlatformDependent.UNSAFE.getInt(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else { + p = 0; + } + p = java.lang.Long.reverseBytes(p); } else { - p = 0; + // byteOrder == ByteOrder.BIG_ENDIAN + if (numBytes >= 8) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + } else if (numBytes > 4) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + mask = (1L << (8 - numBytes) * 8) - 1; + } else if (numBytes > 0) { + p = ((long) PlatformDependent.UNSAFE.getInt(base, offset)) << 32; + mask = (1L << (8 - numBytes) * 8) - 1; + } else { + p = 0; + } } - p = java.lang.Long.reverseBytes(p); + p &= ~mask; return p; }