Skip to content
Snippets Groups Projects
Commit 0e2109dd authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Merge pull request #204 from rxin/hash

OpenHashSet fixes

Incorporated ideas from pull request #200.
- Use the Murmur Hash 3 finalization step to scramble the bits of the hash code
  instead of the simpler version in java.util.HashMap; the latter
  had trouble with ranges of consecutive integers. Murmur Hash 3 is also used
  by fastutil.
- Don't check keys for equality when re-inserting due to growing the
  table; the keys will already be unique.
- Remember the grow threshold instead of recomputing it on each insert.

Also added unit tests for size estimation for specialized hash sets and maps.
parents c46067f0 466fd064
No related branches found
No related tags found
No related merge requests found
......@@ -79,6 +79,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
// Number of slots in the table; initialCapacity is passed through nextPowerOf2.
protected var _capacity = nextPowerOf2(initialCapacity)
// Bit mask (capacity - 1) applied with `&` to wrap probe positions into the table.
protected var _mask = _capacity - 1
// Number of keys currently stored; incremented whenever a new key is inserted.
protected var _size = 0
// Size above which a rehash is triggered; cached as (loadFactor * capacity).toInt.
protected var _growThreshold = (loadFactor * _capacity).toInt
// One bit per slot; a set bit marks the slot as occupied.
protected var _bitset = new BitSet(_capacity)
......@@ -115,7 +116,29 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
* @return The position where the key is placed, plus the highest order bit is set if the key
* did not exist previously (i.e. it was newly inserted by this call).
*/
def addWithoutResize(k: T): Int = putInto(_bitset, _data, k)
def addWithoutResize(k: T): Int = {
  // Home slot: scrambled hash of the key, wrapped into the table by the mask.
  var slot = hashcode(hasher.hash(k)) & _mask
  // Probe step; grows by one after every collision.
  var step = 1
  var result = INVALID_POS
  var probing = true
  while (probing) {
    if (_bitset.get(slot)) {
      if (_data(slot) == k) {
        // The key is already stored here: report the bare position.
        result = slot
        probing = false
      } else {
        // Occupied by a different key: move on to the next probe position.
        slot = (slot + step) & _mask
        step += 1
      }
    } else {
      // Free slot: store the key, mark the slot used and tag the position as new.
      _data(slot) = k
      _bitset.set(slot)
      _size += 1
      result = slot | NONEXISTENCE_MASK
      probing = false
    }
  }
  result
}
/**
* Rehash the set if it is overloaded.
......@@ -126,7 +149,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
* to a new position (in the new data array).
*/
// Calls rehash(k, allocateFunc, moveFunc) once _size has passed the load limit.
def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
// NOTE(review): the next two lines look like the before/after forms of the same check
// (limit recomputed on every call vs. the cached _growThreshold) — confirm which one applies.
if (_size > loadFactor * _capacity) {
if (_size > _growThreshold) {
rehash(k, allocateFunc, moveFunc)
}
}
......@@ -160,37 +183,6 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
*/
// Delegates to _bitset.nextSetBit(fromPos); presumably the next occupied slot from fromPos — confirm.
def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
/**
 * Insert key `k` into the given `data` array / `bitset` pair and return the slot it occupies.
 * When the key was not already present, the returned position additionally has
 * NONEXISTENCE_MASK OR-ed into it and `_size` is incremented; when the key was already
 * present, the bare position is returned.
 *
 * This function assumes the data array has at least one empty slot.
 */
private def putInto(bitset: BitSet, data: Array[T], k: T): Int = {
  // The mask is derived from the target array, not from the current table.
  val wrap = data.length - 1
  var slot = hashcode(hasher.hash(k)) & wrap
  // Probe step; grows by one after every collision.
  var step = 1
  var result = INVALID_POS
  var probing = true
  while (probing) {
    if (bitset.get(slot)) {
      if (data(slot) == k) {
        // The key is already stored here: report the bare position.
        result = slot
        probing = false
      } else {
        // Occupied by a different key: move on to the next probe position.
        slot = (slot + step) & wrap
        step += 1
      }
    } else {
      // Free slot: store the key, mark the slot used and tag the position as new.
      data(slot) = k
      bitset.set(slot)
      _size += 1
      result = slot | NONEXISTENCE_MASK
      probing = false
    }
  }
  result
}
/**
* Double the table's size and re-hash everything. We are not really using k, but it is declared
* so Scala compiler can specialize this method (which leads to calling the specialized version
......@@ -204,34 +196,49 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
*/
// Doubles the capacity and re-inserts every stored key into newly allocated storage.
// NOTE(review): this span interleaves lines from both sides of a diff (two `val newData`
// declarations, `pos` and `oldPos` counters, two `_mask` assignments) — confirm which version
// is current before relying on it.
private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
val newCapacity = _capacity * 2
// Refuse to grow past 2^29 slots.
require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
// Let the caller allocate its own storage for the new capacity.
allocateFunc(newCapacity)
val newData = new Array[T](newCapacity)
val newBitset = new BitSet(newCapacity)
var pos = 0
_size = 0
while (pos < _capacity) {
if (_bitset.get(pos)) {
val newPos = putInto(newBitset, newData, _data(pos))
moveFunc(pos, newPos & POSITION_MASK)
val newData = new Array[T](newCapacity)
val newMask = newCapacity - 1
var oldPos = 0
while (oldPos < capacity) {
if (_bitset.get(oldPos)) {
val key = _data(oldPos)
var newPos = hashcode(hasher.hash(key)) & newMask
var i = 1
var keepGoing = true
// No need to check for equality here when we insert so this has one less if branch than
// the similar code path in addWithoutResize.
while (keepGoing) {
if (!newBitset.get(newPos)) {
// Inserting the key at newPos
newData(newPos) = key
newBitset.set(newPos)
// Report the old -> new slot mapping to the caller.
moveFunc(oldPos, newPos)
keepGoing = false
} else {
// Slot taken: advance by an increasing step (1, 2, 3, ...) and wrap with the mask.
val delta = i
newPos = (newPos + delta) & newMask
i += 1
}
}
}
pos += 1
oldPos += 1
}
// Swap in the new storage and refresh the fields derived from the capacity.
_bitset = newBitset
_data = newData
_capacity = newCapacity
_mask = newCapacity - 1
_mask = newMask
// Cache the resize threshold for the new capacity.
_growThreshold = (loadFactor * newCapacity).toInt
}
/**
* Re-hash a value to deal better with hash functions that don't differ
* in the lower bits, similar to java.util.HashMap
* Re-hash a value to deal better with hash functions that don't differ in the lower bits.
* We use the Murmur Hash 3 finalization step that's also used in fastutil.
*/
// Mixes the bits of h by XOR-ing it with right-shifted copies (shifts of 20 and 12, then 7 and 4).
private def hashcode(h: Int): Int = {
val r = h ^ (h >>> 20) ^ (h >>> 12)
r ^ (r >>> 7) ^ (r >>> 4)
}
// Delegates the bit mixing to fastutil's HashCommon.murmurHash3.
private def hashcode(h: Int): Int = it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
private def nextPowerOf2(n: Int): Int = {
val highBit = Integer.highestOneBit(n)
......
......@@ -2,8 +2,20 @@ package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
class OpenHashMapSuite extends FunSuite {
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.util.SizeEstimator
class OpenHashMapSuite extends FunSuite with ShouldMatchers {
test("size for specialized, primitive value (int)") {
  val numSlots = 1024
  val estimatedBytes = SizeEstimator.estimate(new OpenHashMap[String, Int](numSlots))
  // Per slot: a 64-bit pointer for the key, a 32-bit int for the value, 1 bit in the bitset.
  val bitsPerSlot = 64 + 32 + 1
  val expectedBytes = numSlots * bitsPerSlot / 8
  // Allow 10% slack over the expected footprint, but no significant extra allocation.
  val upperBoundBytes = (expectedBytes * 1.1).toLong
  estimatedBytes should be <= (upperBoundBytes)
}
test("initialization") {
val goodMap1 = new OpenHashMap[String, Int](1)
......
package org.apache.spark.util.collection
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.util.SizeEstimator
class OpenHashSetSuite extends FunSuite {
class OpenHashSetSuite extends FunSuite with ShouldMatchers {
test("size for specialized, primitive int") {
  val intSet = new OpenHashSet[Int](64, 0.7)
  // Insert 0..1023 so the set has to grow beyond its initial capacity of 64.
  var next = 0
  while (next < 1024) {
    intSet.add(next)
    next += 1
  }
  assert(intSet.size === 1024)
  assert(intSet.capacity > 1024)
  val estimatedBytes = SizeEstimator.estimate(intSet)
  // Per slot: a 32-bit int for the element plus 1 bit in the bitset.
  val bitsPerSlot = 32 + 1
  val expectedBytes = intSet.capacity * bitsPerSlot / 8
  // Allow 10% slack over the expected footprint, but no significant extra allocation.
  val upperBoundBytes = (expectedBytes * 1.1).toLong
  estimatedBytes should be <= (upperBoundBytes)
}
test("primitive int") {
val set = new OpenHashSet[Int]
......
......@@ -2,8 +2,20 @@ package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.util.SizeEstimator
class PrimitiveKeyOpenHashSetSuite extends FunSuite {
class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers {
test("size for specialized, primitive key, value (int, int)") {
  val numSlots = 1024
  val estimatedBytes = SizeEstimator.estimate(new PrimitiveKeyOpenHashMap[Int, Int](numSlots))
  // Per slot: a 32-bit key, a 32-bit value and 1 bit in the bitset.
  val bitsPerSlot = 32 + 32 + 1
  val expectedBytes = numSlots * bitsPerSlot / 8
  // Allow 10% slack over the expected footprint, but no significant extra allocation.
  val upperBoundBytes = (expectedBytes * 1.1).toLong
  estimatedBytes should be <= (upperBoundBytes)
}
test("initialization") {
val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment