diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 0c5c12b7a83a4f86a8d3118819d2f772aae50981..fe932d8ede2f3a480eb5b0f2ce2eddce853d8a89 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,13 +18,12 @@ package org.apache.spark.util import java.io._ -import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} +import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address} import java.util.{Locale, Random, UUID} -import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} -import java.util.regex.Pattern +import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor} import scala.collection.Map -import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.io.Source @@ -36,8 +35,7 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.deploy.SparkHadoopUtil import java.nio.ByteBuffer -import org.apache.spark.{SparkEnv, SparkException, Logging} -import java.util.ConcurrentModificationException +import org.apache.spark.{SparkException, Logging} /** @@ -149,7 +147,7 @@ private[spark] object Utils extends Logging { return buf } - private val shutdownDeletePaths = new collection.mutable.HashSet[String]() + private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() // Register the path to be deleted via shutdown hook def registerShutdownDeleteDir(file: File) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala new file mode 100644 index 0000000000000000000000000000000000000000..a1a452315d1437d35ff674b496371224edd98dea --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + + +/** + * A simple, fixed-size bit set implementation. This implementation is fast because it avoids + * safety/bound checking. + */ +class BitSet(numBits: Int) { + + private[this] val words = new Array[Long](bit2words(numBits)) + private[this] val numWords = words.length + + /** + * Sets the bit at the specified index to true. + * @param index the bit index + */ + def set(index: Int) { + val bitmask = 1L << (index & 0x3f) // mod 64 and shift + words(index >> 6) |= bitmask // div by 64 and mask + } + + /** + * Return the value of the bit with the specified index. 
The value is true if the bit with + * the index is currently set in this BitSet; otherwise, the result is false. + * + * @param index the bit index + * @return the value of the bit with the specified index + */ + def get(index: Int): Boolean = { + val bitmask = 1L << (index & 0x3f) // mod 64 and shift + (words(index >> 6) & bitmask) != 0 // div by 64 and mask + } + + /** Return the number of bits set to true in this BitSet. */ + def cardinality(): Int = { + var sum = 0 + var i = 0 + while (i < numWords) { + sum += java.lang.Long.bitCount(words(i)) + i += 1 + } + sum + } + + /** + * Returns the index of the first bit that is set to true that occurs on or after the + * specified starting index. If no such bit exists then -1 is returned. + * + * To iterate over the true bits in a BitSet, use the following loop: + * + * for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) { + * // operate on index i here + * } + * + * @param fromIndex the index to start checking from (inclusive) + * @return the index of the next set bit, or -1 if there is no such bit + */ + def nextSetBit(fromIndex: Int): Int = { + var wordIndex = fromIndex >> 6 + if (wordIndex >= numWords) { + return -1 + } + + // Try to find the next set bit in the current word + val subIndex = fromIndex & 0x3f + var word = words(wordIndex) >> subIndex + if (word != 0) { + return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word) + } + + // Find the next set bit in the rest of the words + wordIndex += 1 + while (wordIndex < numWords) { + word = words(wordIndex) + if (word != 0) { + return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word) + } + wordIndex += 1 + } + + -1 + } + + /** Return the number of longs it would take to hold numBits. */ + private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1 +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..80545c9688aa603cd3cd84f263ad0f4e54fa0b97 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + + +/** + * A fast hash map implementation for nullable keys. This hash map supports insertions and updates, + * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less + * space overhead. + * + * Under the hood, it uses our OpenHashSet implementation. 
+ */ +private[spark] +class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: ClassManifest]( + initialCapacity: Int) + extends Iterable[(K, V)] + with Serializable { + + def this() = this(64) + + protected var _keySet = new OpenHashSet[K](initialCapacity) + + // Init in constructor (instead of in declaration) to work around a Scala compiler specialization + // bug that would generate two arrays (one for Object and one for specialized T). + private var _values: Array[V] = _ + _values = new Array[V](_keySet.capacity) + + @transient private var _oldValues: Array[V] = null + + // Treat the null key differently so we can use nulls in "data" to represent empty items. + private var haveNullValue = false + private var nullValue: V = null.asInstanceOf[V] + + override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size + + /** Get the value for a given key */ + def apply(k: K): V = { + if (k == null) { + nullValue + } else { + val pos = _keySet.getPos(k) + if (pos < 0) { + null.asInstanceOf[V] + } else { + _values(pos) + } + } + } + + /** Set the value for a key */ + def update(k: K, v: V) { + if (k == null) { + haveNullValue = true + nullValue = v + } else { + val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK + _values(pos) = v + _keySet.rehashIfNeeded(k, grow, move) + _oldValues = null + } + } + + /** + * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise, + * set its value to mergeValue(oldValue). + * + * @return the newly updated value. + */ + def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = { + if (k == null) { + if (haveNullValue) { + nullValue = mergeValue(nullValue) + } else { + haveNullValue = true + nullValue = defaultValue + } + nullValue + } else { + val pos = _keySet.addWithoutResize(k) + if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { + val newValue = defaultValue + _values(pos & OpenHashSet.POSITION_MASK) = newValue + _keySet.rehashIfNeeded(k, grow, move) + newValue + } else { + _values(pos) = mergeValue(_values(pos)) + _values(pos) + } + } + } + + override def iterator = new Iterator[(K, V)] { + var pos = -1 + var nextPair: (K, V) = computeNextPair() + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def computeNextPair(): (K, V) = { + if (pos == -1) { // Treat position -1 as looking at the null value + if (haveNullValue) { + pos += 1 + return (null.asInstanceOf[K], nullValue) + } + pos += 1 + } + pos = _keySet.nextPos(pos) + if (pos >= 0) { + val ret = (_keySet.getValue(pos), _values(pos)) + pos += 1 + ret + } else { + null + } + } + + def hasNext = nextPair != null + + def next() = { + val pair = nextPair + nextPair = computeNextPair() + pair + } + } + + // The following member variables are declared as protected instead of private for the + // specialization to work (specialized class extends the non-specialized one and needs access + // to the "private" variables). + // They also should have been val's. We use var's because there is a Scala compiler bug that + // would throw illegal access error at runtime if they are declared as val's. 
+ protected var grow = (newCapacity: Int) => { + _oldValues = _values + _values = new Array[V](newCapacity) + } + + protected var move = (oldPos: Int, newPos: Int) => { + _values(newPos) = _oldValues(oldPos) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala new file mode 100644 index 0000000000000000000000000000000000000000..4592e4f939e5c570b0a19abccd26430e35e2143c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + + +/** + * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never + * removed. + * + * The underlying implementation uses Scala compiler's specialization to generate optimized + * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet + * while incurring much less memory overhead. This can serve as building blocks for higher level + * data structures such as an optimized HashMap. + * + * This OpenHashSet is designed to serve as building blocks for higher level data structures + * such as an optimized hash map. Compared with standard hash set implementations, this class + * provides its various callbacks interfaces (e.g. allocateFunc, moveFunc) and interfaces to + * retrieve the position of a key in the underlying array. + * + * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed + * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). + */ +private[spark] +class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( + initialCapacity: Int, + loadFactor: Double) + extends Serializable { + + require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + require(initialCapacity >= 1, "Invalid initial capacity") + require(loadFactor < 1.0, "Load factor must be less than 1.0") + require(loadFactor > 0.0, "Load factor must be greater than 0.0") + + import OpenHashSet._ + + def this(initialCapacity: Int) = this(initialCapacity, 0.7) + + def this() = this(64) + + // The following member variables are declared as protected instead of private for the + // specialization to work (specialized class extends the non-specialized one and needs access + // to the "private" variables). + + protected val hasher: Hasher[T] = { + // It would've been more natural to write the following using pattern matching. 
But Scala 2.9.x + // compiler has a bug when specialization is used together with this pattern matching, and + // throws: + // scala.tools.nsc.symtab.Types$TypeError: type mismatch; + // found : scala.reflect.AnyValManifest[Long] + // required: scala.reflect.ClassManifest[Int] + // at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298) + // at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207) + // ... + val mt = classManifest[T] + if (mt == ClassManifest.Long) { + (new LongHasher).asInstanceOf[Hasher[T]] + } else if (mt == ClassManifest.Int) { + (new IntHasher).asInstanceOf[Hasher[T]] + } else { + new Hasher[T] + } + } + + protected var _capacity = nextPowerOf2(initialCapacity) + protected var _mask = _capacity - 1 + protected var _size = 0 + + protected var _bitset = new BitSet(_capacity) + + // Init of the array in constructor (instead of in declaration) to work around a Scala compiler + // specialization bug that would generate two arrays (one for Object and one for specialized T). + protected var _data: Array[T] = _ + _data = new Array[T](_capacity) + + /** Number of elements in the set. */ + def size: Int = _size + + /** The capacity of the set (i.e. size of the underlying array). */ + def capacity: Int = _capacity + + /** Return true if this set contains the specified element. */ + def contains(k: T): Boolean = getPos(k) != INVALID_POS + + /** + * Add an element to the set. If the set is over capacity after the insertion, grow the set + * and rehash all elements. + */ + def add(k: T) { + addWithoutResize(k) + rehashIfNeeded(k, grow, move) + } + + /** + * Add an element to the set. This one differs from add in that it doesn't trigger rehashing. + * The caller is responsible for calling rehashIfNeeded. + * + * Use (retval & POSITION_MASK) to get the actual position, and + * (retval & EXISTENCE_MASK) != 0 for prior existence. + * + * @return The position where the key is placed, plus the highest order bit is set if the key + * exists previously. + */ + def addWithoutResize(k: T): Int = putInto(_bitset, _data, k) + + /** + * Rehash the set if it is overloaded. + * @param k A parameter unused in the function, but to force the Scala compiler to specialize + * this method. + * @param allocateFunc Callback invoked when we are allocating a new, larger array. + * @param moveFunc Callback invoked when we move the key from one position (in the old data array) + * to a new position (in the new data array). + */ + def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { + if (_size > loadFactor * _capacity) { + rehash(k, allocateFunc, moveFunc) + } + } + + /** + * Return the position of the element in the underlying array, or INVALID_POS if it is not found. + */ + def getPos(k: T): Int = { + var pos = hashcode(hasher.hash(k)) & _mask + var i = 1 + while (true) { + if (!_bitset.get(pos)) { + return INVALID_POS + } else if (k == _data(pos)) { + return pos + } else { + val delta = i + pos = (pos + delta) & _mask + i += 1 + } + } + // Never reached here + INVALID_POS + } + + /** Return the value at the specified position. */ + def getValue(pos: Int): T = _data(pos) + + /** + * Return the next position with an element stored, starting from the given position inclusively. + */ + def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos) + + /** + * Put an entry into the set. Return the position where the key is placed. In addition, the + * highest bit in the returned position is set if the key exists prior to this put. 
+ * + * This function assumes the data array has at least one empty slot. + */ + private def putInto(bitset: BitSet, data: Array[T], k: T): Int = { + val mask = data.length - 1 + var pos = hashcode(hasher.hash(k)) & mask + var i = 1 + while (true) { + if (!bitset.get(pos)) { + // This is a new key. + data(pos) = k + bitset.set(pos) + _size += 1 + return pos | NONEXISTENCE_MASK + } else if (data(pos) == k) { + // Found an existing key. + return pos + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } + } + // Never reached here + assert(INVALID_POS != INVALID_POS) + INVALID_POS + } + + /** + * Double the table's size and re-hash everything. We are not really using k, but it is declared + * so Scala compiler can specialize this method (which leads to calling the specialized version + * of putInto). + * + * @param k A parameter unused in the function, but to force the Scala compiler to specialize + * this method. + * @param allocateFunc Callback invoked when we are allocating a new, larger array. + * @param moveFunc Callback invoked when we move the key from one position (in the old data array) + * to a new position (in the new data array). + */ + private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { + val newCapacity = _capacity * 2 + require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + + allocateFunc(newCapacity) + val newData = new Array[T](newCapacity) + val newBitset = new BitSet(newCapacity) + var pos = 0 + _size = 0 + while (pos < _capacity) { + if (_bitset.get(pos)) { + val newPos = putInto(newBitset, newData, _data(pos)) + moveFunc(pos, newPos & POSITION_MASK) + } + pos += 1 + } + _bitset = newBitset + _data = newData + _capacity = newCapacity + _mask = newCapacity - 1 + } + + /** + * Re-hash a value to deal better with hash functions that don't differ + * in the lower bits, similar to java.util.HashMap + */ + private def hashcode(h: Int): Int = { + val r = h ^ (h >>> 20) ^ (h >>> 12) + r ^ (r >>> 7) ^ (r >>> 4) + } + + private def nextPowerOf2(n: Int): Int = { + val highBit = Integer.highestOneBit(n) + if (highBit == n) n else highBit << 1 + } +} + + +private[spark] +object OpenHashSet { + + val INVALID_POS = -1 + val NONEXISTENCE_MASK = 0x80000000 + val POSITION_MASK = 0xEFFFFFF + + /** + * A set of specialized hash function implementation to avoid boxing hash code computation + * in the specialized implementation of OpenHashSet. + */ + sealed class Hasher[@specialized(Long, Int) T] { + def hash(o: T): Int = o.hashCode() + } + + class LongHasher extends Hasher[Long] { + override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt + } + + class IntHasher extends Hasher[Int] { + override def hash(o: Int): Int = o + } + + private def grow1(newSize: Int) {} + private def move1(oldPos: Int, newPos: Int) { } + + private val grow = grow1 _ + private val move = move1 _ +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..4adf9cfb7611204a5c0d63e2ebc1a3d57dbc931c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + + +/** + * A fast hash map implementation for primitive, non-null keys. This hash map supports + * insertions and updates, but not deletions. This map is about an order of magnitude + * faster than java.util.HashMap, while using much less space overhead. + * + * Under the hood, it uses our OpenHashSet implementation. + */ +private[spark] +class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest, + @specialized(Long, Int, Double) V: ClassManifest]( + initialCapacity: Int) + extends Iterable[(K, V)] + with Serializable { + + def this() = this(64) + + require(classManifest[K] == classManifest[Long] || classManifest[K] == classManifest[Int]) + + // Init in constructor (instead of in declaration) to work around a Scala compiler specialization + // bug that would generate two arrays (one for Object and one for specialized T). + protected var _keySet: OpenHashSet[K] = _ + private var _values: Array[V] = _ + _keySet = new OpenHashSet[K](initialCapacity) + _values = new Array[V](_keySet.capacity) + + private var _oldValues: Array[V] = null + + override def size = _keySet.size + + /** Get the value for a given key */ + def apply(k: K): V = { + val pos = _keySet.getPos(k) + _values(pos) + } + + /** Set the value for a key */ + def update(k: K, v: V) { + val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK + _values(pos) = v + _keySet.rehashIfNeeded(k, grow, move) + _oldValues = null + } + + /** + * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise, + * set its value to mergeValue(oldValue). + * + * @return the newly updated value. + */ + def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = { + val pos = _keySet.addWithoutResize(k) + if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { + val newValue = defaultValue + _values(pos & OpenHashSet.POSITION_MASK) = newValue + _keySet.rehashIfNeeded(k, grow, move) + newValue + } else { + _values(pos) = mergeValue(_values(pos)) + _values(pos) + } + } + + override def iterator = new Iterator[(K, V)] { + var pos = 0 + var nextPair: (K, V) = computeNextPair() + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def computeNextPair(): (K, V) = { + pos = _keySet.nextPos(pos) + if (pos >= 0) { + val ret = (_keySet.getValue(pos), _values(pos)) + pos += 1 + ret + } else { + null + } + } + + def hasNext = nextPair != null + + def next() = { + val pair = nextPair + nextPair = computeNextPair() + pair + } + } + + // The following member variables are declared as protected instead of private for the + // specialization to work (specialized class extends the unspecialized one and needs access + // to the "private" variables). + // They also should have been val's. We use var's because there is a Scala compiler bug that + // would throw illegal access error at runtime if they are declared as val's. 
+ protected var grow = (newCapacity: Int) => { + _oldValues = _values + _values = new Array[V](newCapacity) + } + + protected var move = (oldPos: Int, newPos: Int) => { + _values(newPos) = _oldValues(oldPos) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..0f1ab3d20eea4456385f26df2c724595450e6234 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/BitSetSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import org.scalatest.FunSuite + + +class BitSetSuite extends FunSuite { + + test("basic set and get") { + val setBits = Seq(0, 9, 1, 10, 90, 96) + val bitset = new BitSet(100) + + for (i <- 0 until 100) { + assert(!bitset.get(i)) + } + + setBits.foreach(i => bitset.set(i)) + + for (i <- 0 until 100) { + if (setBits.contains(i)) { + assert(bitset.get(i)) + } else { + assert(!bitset.get(i)) + } + } + assert(bitset.cardinality() === setBits.size) + } + + test("100% full bit set") { + val bitset = new BitSet(10000) + for (i <- 0 until 10000) { + assert(!bitset.get(i)) + bitset.set(i) + } + for (i <- 0 until 10000) { + assert(bitset.get(i)) + } + assert(bitset.cardinality() === 10000) + } + + test("nextSetBit") { + val setBits = Seq(0, 9, 1, 10, 90, 96) + val bitset = new BitSet(100) + setBits.foreach(i => bitset.set(i)) + + assert(bitset.nextSetBit(0) === 0) + assert(bitset.nextSetBit(1) === 1) + assert(bitset.nextSetBit(2) === 9) + assert(bitset.nextSetBit(9) === 9) + assert(bitset.nextSetBit(10) === 10) + assert(bitset.nextSetBit(11) === 90) + assert(bitset.nextSetBit(80) === 90) + assert(bitset.nextSetBit(91) === 96) + assert(bitset.nextSetBit(96) === 96) + assert(bitset.nextSetBit(97) === -1) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..ca3f684668d605e868d491fffb5b4f0bcc4a23a1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -0,0 +1,148 @@ +package org.apache.spark.util.collection + +import scala.collection.mutable.HashSet +import org.scalatest.FunSuite + +class OpenHashMapSuite extends FunSuite { + + test("initialization") { + val goodMap1 = new OpenHashMap[String, Int](1) + assert(goodMap1.size === 0) + val goodMap2 = new OpenHashMap[String, Int](255) + assert(goodMap2.size === 0) + val goodMap3 = new OpenHashMap[String, String](256) + assert(goodMap3.size === 0) + intercept[IllegalArgumentException] { + new 
OpenHashMap[String, Int](1 << 30) // Invalid map size: bigger than 2^29 + } + intercept[IllegalArgumentException] { + new OpenHashMap[String, Int](-1) + } + intercept[IllegalArgumentException] { + new OpenHashMap[String, String](0) + } + } + + test("primitive value") { + val map = new OpenHashMap[String, Int] + + for (i <- 1 to 1000) { + map(i.toString) = i + assert(map(i.toString) === i) + } + + assert(map.size === 1000) + assert(map(null) === 0) + + map(null) = -1 + assert(map.size === 1001) + assert(map(null) === -1) + + for (i <- 1 to 1000) { + assert(map(i.toString) === i) + } + + // Test iterator + val set = new HashSet[(String, Int)] + for ((k, v) <- map) { + set.add((k, v)) + } + val expected = (1 to 1000).map(x => (x.toString, x)) :+ (null.asInstanceOf[String], -1) + assert(set === expected.toSet) + } + + test("non-primitive value") { + val map = new OpenHashMap[String, String] + + for (i <- 1 to 1000) { + map(i.toString) = i.toString + assert(map(i.toString) === i.toString) + } + + assert(map.size === 1000) + assert(map(null) === null) + + map(null) = "-1" + assert(map.size === 1001) + assert(map(null) === "-1") + + for (i <- 1 to 1000) { + assert(map(i.toString) === i.toString) + } + + // Test iterator + val set = new HashSet[(String, String)] + for ((k, v) <- map) { + set.add((k, v)) + } + val expected = (1 to 1000).map(_.toString).map(x => (x, x)) :+ (null.asInstanceOf[String], "-1") + assert(set === expected.toSet) + } + + test("null keys") { + val map = new OpenHashMap[String, String]() + for (i <- 1 to 100) { + map(i.toString) = i.toString + } + assert(map.size === 100) + assert(map(null) === null) + map(null) = "hello" + assert(map.size === 101) + assert(map(null) === "hello") + } + + test("null values") { + val map = new OpenHashMap[String, String]() + for (i <- 1 to 100) { + map(i.toString) = null + } + assert(map.size === 100) + assert(map("1") === null) + assert(map(null) === null) + assert(map.size === 100) + map(null) = null + assert(map.size === 101) + assert(map(null) === null) + } + + test("changeValue") { + val map = new OpenHashMap[String, String]() + for (i <- 1 to 100) { + map(i.toString) = i.toString + } + assert(map.size === 100) + for (i <- 1 to 100) { + val res = map.changeValue(i.toString, { assert(false); "" }, v => { + assert(v === i.toString) + v + "!" + }) + assert(res === i + "!") + } + // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a + // bug where changeValue would return the wrong result when the map grew on that insert + for (i <- 101 to 400) { + val res = map.changeValue(i.toString, { i + "!" }, v => { assert(false); v }) + assert(res === i + "!") + } + assert(map.size === 400) + assert(map(null) === null) + map.changeValue(null, { "null!" }, v => { assert(false); v }) + assert(map.size === 401) + map.changeValue(null, { assert(false); "" }, v => { + assert(v === "null!") + "null!!" 
+ }) + assert(map.size === 401) + } + + test("inserting in capacity-1 map") { + val map = new OpenHashMap[String, String](1) + for (i <- 1 to 100) { + map(i.toString) = i.toString + } + assert(map.size === 100) + for (i <- 1 to 100) { + assert(map(i.toString) === i.toString) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..4e11e8a628b44e3dffa1b076263cfc3696eea438 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -0,0 +1,145 @@ +package org.apache.spark.util.collection + +import org.scalatest.FunSuite + + +class OpenHashSetSuite extends FunSuite { + + test("primitive int") { + val set = new OpenHashSet[Int] + assert(set.size === 0) + assert(!set.contains(10)) + assert(!set.contains(50)) + assert(!set.contains(999)) + assert(!set.contains(10000)) + + set.add(10) + assert(set.contains(10)) + assert(!set.contains(50)) + assert(!set.contains(999)) + assert(!set.contains(10000)) + + set.add(50) + assert(set.size === 2) + assert(set.contains(10)) + assert(set.contains(50)) + assert(!set.contains(999)) + assert(!set.contains(10000)) + + set.add(999) + assert(set.size === 3) + assert(set.contains(10)) + assert(set.contains(50)) + assert(set.contains(999)) + assert(!set.contains(10000)) + + set.add(50) + assert(set.size === 3) + assert(set.contains(10)) + assert(set.contains(50)) + assert(set.contains(999)) + assert(!set.contains(10000)) + } + + test("primitive long") { + val set = new OpenHashSet[Long] + assert(set.size === 0) + assert(!set.contains(10L)) + assert(!set.contains(50L)) + assert(!set.contains(999L)) + assert(!set.contains(10000L)) + + set.add(10L) + assert(set.size === 1) + assert(set.contains(10L)) + assert(!set.contains(50L)) + assert(!set.contains(999L)) + assert(!set.contains(10000L)) + + set.add(50L) + assert(set.size === 2) + assert(set.contains(10L)) + assert(set.contains(50L)) + assert(!set.contains(999L)) + assert(!set.contains(10000L)) + + set.add(999L) + assert(set.size === 3) + assert(set.contains(10L)) + assert(set.contains(50L)) + assert(set.contains(999L)) + assert(!set.contains(10000L)) + + set.add(50L) + assert(set.size === 3) + assert(set.contains(10L)) + assert(set.contains(50L)) + assert(set.contains(999L)) + assert(!set.contains(10000L)) + } + + test("non-primitive") { + val set = new OpenHashSet[String] + assert(set.size === 0) + assert(!set.contains(10.toString)) + assert(!set.contains(50.toString)) + assert(!set.contains(999.toString)) + assert(!set.contains(10000.toString)) + + set.add(10.toString) + assert(set.size === 1) + assert(set.contains(10.toString)) + assert(!set.contains(50.toString)) + assert(!set.contains(999.toString)) + assert(!set.contains(10000.toString)) + + set.add(50.toString) + assert(set.size === 2) + assert(set.contains(10.toString)) + assert(set.contains(50.toString)) + assert(!set.contains(999.toString)) + assert(!set.contains(10000.toString)) + + set.add(999.toString) + assert(set.size === 3) + assert(set.contains(10.toString)) + assert(set.contains(50.toString)) + assert(set.contains(999.toString)) + assert(!set.contains(10000.toString)) + + set.add(50.toString) + assert(set.size === 3) + assert(set.contains(10.toString)) + assert(set.contains(50.toString)) + assert(set.contains(999.toString)) + assert(!set.contains(10000.toString)) + } + + test("non-primitive set growth") { + val set = new 
OpenHashSet[String] + for (i <- 1 to 1000) { + set.add(i.toString) + } + assert(set.size === 1000) + assert(set.capacity > 1000) + for (i <- 1 to 100) { + set.add(i.toString) + } + assert(set.size === 1000) + assert(set.capacity > 1000) + } + + test("primitive set growth") { + val set = new OpenHashSet[Long] + for (i <- 1 to 1000) { + set.add(i.toLong) + } + assert(set.size === 1000) + assert(set.capacity > 1000) + for (i <- 1 to 100) { + set.add(i.toLong) + } + assert(set.size === 1000) + assert(set.capacity > 1000) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..dfd6aed2c4bccf7f1d9a25690ce0c6be41097678 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala @@ -0,0 +1,90 @@ +package org.apache.spark.util.collection + +import scala.collection.mutable.HashSet +import org.scalatest.FunSuite + +class PrimitiveKeyOpenHashSetSuite extends FunSuite { + + test("initialization") { + val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1) + assert(goodMap1.size === 0) + val goodMap2 = new PrimitiveKeyOpenHashMap[Int, Int](255) + assert(goodMap2.size === 0) + val goodMap3 = new PrimitiveKeyOpenHashMap[Int, Int](256) + assert(goodMap3.size === 0) + intercept[IllegalArgumentException] { + new PrimitiveKeyOpenHashMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29 + } + intercept[IllegalArgumentException] { + new PrimitiveKeyOpenHashMap[Int, Int](-1) + } + intercept[IllegalArgumentException] { + new PrimitiveKeyOpenHashMap[Int, Int](0) + } + } + + test("basic operations") { + val longBase = 1000000L + val map = new PrimitiveKeyOpenHashMap[Long, Int] + + for (i <- 1 to 1000) { + map(i + longBase) = i + assert(map(i + longBase) === i) + } + + assert(map.size === 1000) + + for (i <- 1 to 1000) { + assert(map(i + longBase) === i) + } + + // Test iterator + val set = new HashSet[(Long, Int)] + for ((k, v) <- map) { + set.add((k, v)) + } + assert(set === (1 to 1000).map(x => (x + longBase, x)).toSet) + } + + test("null values") { + val map = new PrimitiveKeyOpenHashMap[Long, String]() + for (i <- 1 to 100) { + map(i.toLong) = null + } + assert(map.size === 100) + assert(map(1.toLong) === null) + } + + test("changeValue") { + val map = new PrimitiveKeyOpenHashMap[Long, String]() + for (i <- 1 to 100) { + map(i.toLong) = i.toString + } + assert(map.size === 100) + for (i <- 1 to 100) { + val res = map.changeValue(i.toLong, { assert(false); "" }, v => { + assert(v === i.toString) + v + "!" + }) + assert(res === i + "!") + } + // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a + // bug where changeValue would return the wrong result when the map grew on that insert + for (i <- 101 to 400) { + val res = map.changeValue(i.toLong, { i + "!" }, v => { assert(false); v }) + assert(res === i + "!") + } + assert(map.size === 400) + } + + test("inserting in capacity-1 map") { + val map = new PrimitiveKeyOpenHashMap[Long, String](1) + for (i <- 1 to 100) { + map(i.toLong) = i.toString + } + assert(map.size === 100) + for (i <- 1 to 100) { + assert(map(i.toLong) === i.toString) + } + } +}
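
Usage sketch (illustrative, not part of the patch): the snippet below shows how the collections introduced in this diff compose, assuming the classes above are compiled and available. The object name OpenHashExample, the sample data, and the no-op rehash callbacks are made up for illustration; the snippet is placed in the same package only to satisfy the private[spark] visibility of these classes, and the addWithoutResize/rehashIfNeeded pattern follows how OpenHashMap.update drives OpenHashSet in the code above.

    package org.apache.spark.util.collection

    // Illustrative sketch only; not part of this patch.
    object OpenHashExample {
      def main(args: Array[String]) {
        // OpenHashMap: nullable keys, supports inserts and updates but no deletions.
        val counts = new OpenHashMap[String, Int]
        Seq("a", "b", "a", null, "c", "a").foreach { w =>
          // changeValue stores defaultValue for a new key, otherwise applies mergeValue.
          counts.changeValue(w, 1, v => v + 1)
        }
        for ((k, v) <- counts) println(k + " -> " + v)  // e.g. a -> 3, null -> 1, ...

        // PrimitiveKeyOpenHashMap: specialized for Long/Int keys, avoids boxing.
        val byId = new PrimitiveKeyOpenHashMap[Long, String]
        byId(42L) = "forty-two"
        println(byId(42L))

        // OpenHashSet's low-level position API, as used by the maps above:
        // addWithoutResize returns the slot position, with the high bit set when the
        // key was not already present; the caller then triggers any needed rehash.
        val set = new OpenHashSet[Int]
        val ret = set.addWithoutResize(7)
        val pos = ret & OpenHashSet.POSITION_MASK
        val isNew = (ret & OpenHashSet.NONEXISTENCE_MASK) != 0
        set.rehashIfNeeded(7, (newCapacity: Int) => (), (oldPos: Int, newPos: Int) => ())
        println("stored at " + pos + ", newly inserted: " + isNew)
      }
    }

The grow/move callbacks passed to rehashIfNeeded are no-ops here; OpenHashMap and PrimitiveKeyOpenHashMap use them to resize and repopulate their parallel value arrays when the underlying set grows.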