From eb5f8a3f977688beb2f068050d8fabe7e15141d3 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@apache.org>
Date: Sun, 3 Nov 2013 18:04:21 -0800
Subject: [PATCH] Code review feedback.

---
 .../apache/spark/util/collection/BitSet.scala |  4 +-
 .../spark/util/collection/OpenHashMap.scala   |  2 +-
 .../spark/util/collection/OpenHashSet.scala   | 20 +++--
 .../collection/PrimitiveKeyOpenHashMap.scala  |  2 +-
 .../util/collection/OpenHashMapSuite.scala    | 16 ++--
 .../util/collection/OpenHashSetSuite.scala    | 73 ++++++++++++++++++-
 .../PrimitiveKeyOpenHashSetSuite.scala        |  8 +-
 7 files changed, 100 insertions(+), 25 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
index 6604ec738c..a1a452315d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -45,7 +45,7 @@ class BitSet(numBits: Int) {
    */
   def get(index: Int): Boolean = {
     val bitmask = 1L << (index & 0x3f)   // mod 64 and shift
-    (words(index >>> 6) & bitmask) != 0  // div by 64 and mask
+    (words(index >> 6) & bitmask) != 0   // div by 64 and mask
   }
 
   /** Return the number of bits set to true in this BitSet. */
@@ -99,5 +99,5 @@ class BitSet(numBits: Int) {
   }
 
   /** Return the number of longs it would take to hold numBits. */
-  private def bit2words(numBits: Int) = ((numBits - 1) >>> 6) + 1
+  private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1
 }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
index ed117b2abf..80545c9688 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -92,7 +92,7 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V:
       nullValue
     } else {
       val pos = _keySet.addWithoutResize(k)
-      if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) {
+      if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
         val newValue = defaultValue
         _values(pos & OpenHashSet.POSITION_MASK) = newValue
         _keySet.rehashIfNeeded(k, grow, move)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index e98a93dc2a..4592e4f939 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -43,6 +43,8 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
 
   require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
   require(initialCapacity >= 1, "Invalid initial capacity")
+  require(loadFactor < 1.0, "Load factor must be less than 1.0")
+  require(loadFactor > 0.0, "Load factor must be greater than 0.0")
 
   import OpenHashSet._
 
@@ -119,8 +121,8 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
    * Rehash the set if it is overloaded.
    * @param k A parameter unused in the function, but to force the Scala compiler to specialize
    *          this method.
-   * @param allocateFunc Closure invoked when we are allocating a new, larger array.
-   * @param moveFunc Closure invoked when we move the key from one position (in the old data array)
+   * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+   * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
    *                 to a new position (in the new data array).
    */
   def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
@@ -129,7 +131,9 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
     }
   }
 
-  /** Return the position of the element in the underlying array. */
+  /**
+   * Return the position of the element in the underlying array, or INVALID_POS if it is not found.
+   */
   def getPos(k: T): Int = {
     var pos = hashcode(hasher.hash(k)) & _mask
     var i = 1
@@ -172,7 +176,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
         data(pos) = k
         bitset.set(pos)
         _size += 1
-        return pos | EXISTENCE_MASK
+        return pos | NONEXISTENCE_MASK
       } else if (data(pos) == k) {
         // Found an existing key.
         return pos
@@ -194,8 +198,8 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
    *
    * @param k A parameter unused in the function, but to force the Scala compiler to specialize
    *          this method.
-   * @param allocateFunc Closure invoked when we are allocating a new, larger array.
-   * @param moveFunc Closure invoked when we move the key from one position (in the old data array)
+   * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+   * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
    *                 to a new position (in the new data array).
    */
   private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
@@ -203,7 +207,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
     require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
 
     allocateFunc(newCapacity)
-    val newData = classManifest[T].newArray(newCapacity)
+    val newData = new Array[T](newCapacity)
     val newBitset = new BitSet(newCapacity)
     var pos = 0
     _size = 0
@@ -240,7 +244,7 @@ private[spark] object OpenHashSet {
 
   val INVALID_POS = -1
 
-  val EXISTENCE_MASK = 0x80000000
+  val NONEXISTENCE_MASK = 0x80000000
   val POSITION_MASK = 0xEFFFFFF
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
index e8f28ecdd7..4adf9cfb76 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -69,7 +69,7 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
    */
   def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
     val pos = _keySet.addWithoutResize(k)
-    if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) {
+    if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
       val newValue = defaultValue
       _values(pos & OpenHashSet.POSITION_MASK) = newValue
       _keySet.rehashIfNeeded(k, grow, move)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
index 5e74ca1f7e..ca3f684668 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -82,7 +82,7 @@ class OpenHashMapSuite extends FunSuite {
   test("null keys") {
     val map = new OpenHashMap[String, String]()
     for (i <- 1 to 100) {
-      map("" + i) = "" + i
+      map(i.toString) = i.toString
     }
     assert(map.size === 100)
     assert(map(null) === null)
@@ -94,7 +94,7 @@ class OpenHashMapSuite extends FunSuite {
   test("null values") {
     val map = new OpenHashMap[String, String]()
     for (i <- 1 to 100) {
-      map("" + i) = null
+      map(i.toString) = null
     }
     assert(map.size === 100)
     assert(map("1") === null)
@@ -108,12 +108,12 @@ class OpenHashMapSuite extends FunSuite {
   test("changeValue") {
     val map = new OpenHashMap[String, String]()
     for (i <- 1 to 100) {
-      map("" + i) = "" + i
+      map(i.toString) = i.toString
     }
     assert(map.size === 100)
     for (i <- 1 to 100) {
-      val res = map.changeValue("" + i, { assert(false); "" }, v => {
-        assert(v === "" + i)
+      val res = map.changeValue(i.toString, { assert(false); "" }, v => {
+        assert(v === i.toString)
         v + "!"
       })
       assert(res === i + "!")
@@ -121,7 +121,7 @@ class OpenHashMapSuite extends FunSuite {
     // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a
     // bug where changeValue would return the wrong result when the map grew on that insert
     for (i <- 101 to 400) {
-      val res = map.changeValue("" + i, { i + "!" }, v => { assert(false); v })
+      val res = map.changeValue(i.toString, { i + "!" }, v => { assert(false); v })
       assert(res === i + "!")
     }
     assert(map.size === 400)
@@ -138,11 +138,11 @@ class OpenHashMapSuite extends FunSuite {
   test("inserting in capacity-1 map") {
     val map = new OpenHashMap[String, String](1)
     for (i <- 1 to 100) {
-      map("" + i) = "" + i
+      map(i.toString) = i.toString
     }
     assert(map.size === 100)
     for (i <- 1 to 100) {
-      assert(map("" + i) === "" + i)
+      assert(map(i.toString) === i.toString)
     }
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
index 40049e8475..4e11e8a628 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -8,40 +8,111 @@ class OpenHashSetSuite extends FunSuite {
   test("primitive int") {
     val set = new OpenHashSet[Int]
     assert(set.size === 0)
+    assert(!set.contains(10))
+    assert(!set.contains(50))
+    assert(!set.contains(999))
+    assert(!set.contains(10000))
+
     set.add(10)
-    assert(set.size === 1)
+    assert(set.contains(10))
+    assert(!set.contains(50))
+    assert(!set.contains(999))
+    assert(!set.contains(10000))
+
     set.add(50)
     assert(set.size === 2)
+    assert(set.contains(10))
+    assert(set.contains(50))
+    assert(!set.contains(999))
+    assert(!set.contains(10000))
+
     set.add(999)
     assert(set.size === 3)
+    assert(set.contains(10))
+    assert(set.contains(50))
+    assert(set.contains(999))
+    assert(!set.contains(10000))
+
     set.add(50)
     assert(set.size === 3)
+    assert(set.contains(10))
+    assert(set.contains(50))
+    assert(set.contains(999))
+    assert(!set.contains(10000))
   }
 
   test("primitive long") {
     val set = new OpenHashSet[Long]
     assert(set.size === 0)
+    assert(!set.contains(10L))
+    assert(!set.contains(50L))
+    assert(!set.contains(999L))
+    assert(!set.contains(10000L))
+
     set.add(10L)
     assert(set.size === 1)
+    assert(set.contains(10L))
+    assert(!set.contains(50L))
+    assert(!set.contains(999L))
+    assert(!set.contains(10000L))
+
     set.add(50L)
     assert(set.size === 2)
+    assert(set.contains(10L))
+    assert(set.contains(50L))
+    assert(!set.contains(999L))
+    assert(!set.contains(10000L))
+
     set.add(999L)
     assert(set.size === 3)
+    assert(set.contains(10L))
+    assert(set.contains(50L))
+    assert(set.contains(999L))
+    assert(!set.contains(10000L))
+
     set.add(50L)
     assert(set.size === 3)
+    assert(set.contains(10L))
+    assert(set.contains(50L))
+    assert(set.contains(999L))
+    assert(!set.contains(10000L))
   }
 
   test("non-primitive") {
     val set = new OpenHashSet[String]
     assert(set.size === 0)
+    assert(!set.contains(10.toString))
+    assert(!set.contains(50.toString))
+    assert(!set.contains(999.toString))
+    assert(!set.contains(10000.toString))
+
     set.add(10.toString)
     assert(set.size === 1)
+    assert(set.contains(10.toString))
+    assert(!set.contains(50.toString))
+    assert(!set.contains(999.toString))
+    assert(!set.contains(10000.toString))
+
     set.add(50.toString)
     assert(set.size === 2)
+    assert(set.contains(10.toString))
+    assert(set.contains(50.toString))
+    assert(!set.contains(999.toString))
+    assert(!set.contains(10000.toString))
+
     set.add(999.toString)
     assert(set.size === 3)
+    assert(set.contains(10.toString))
+    assert(set.contains(50.toString))
+    assert(set.contains(999.toString))
+    assert(!set.contains(10000.toString))
+
     set.add(50.toString)
     assert(set.size === 3)
+    assert(set.contains(10.toString))
+    assert(set.contains(50.toString))
+    assert(set.contains(999.toString))
+    assert(!set.contains(10000.toString))
   }
 
   test("non-primitive set growth") {
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala
index dc7f6cb023..dfd6aed2c4 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala
@@ -58,12 +58,12 @@ class PrimitiveKeyOpenHashSetSuite extends FunSuite {
   test("changeValue") {
     val map = new PrimitiveKeyOpenHashMap[Long, String]()
     for (i <- 1 to 100) {
-      map(i.toLong) = "" + i
+      map(i.toLong) = i.toString
     }
     assert(map.size === 100)
     for (i <- 1 to 100) {
       val res = map.changeValue(i.toLong, { assert(false); "" }, v => {
-        assert(v === "" + i)
+        assert(v === i.toString)
         v + "!"
       })
       assert(res === i + "!")
@@ -80,11 +80,11 @@ class PrimitiveKeyOpenHashSetSuite extends FunSuite {
   test("inserting in capacity-1 map") {
     val map = new PrimitiveKeyOpenHashMap[Long, String](1)
     for (i <- 1 to 100) {
-      map(i.toLong) = "" + i
+      map(i.toLong) = i.toString
    }
     assert(map.size === 100)
     for (i <- 1 to 100) {
-      assert(map(i.toLong) === "" + i)
+      assert(map(i.toLong) === i.toString)
     }
   }
 }
-- 
GitLab
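Note (not part of the patch): a minimal sketch of the caller contract that the
EXISTENCE_MASK -> NONEXISTENCE_MASK rename clarifies. addWithoutResize returns the
key's slot with the high bit set only when the key was newly inserted, so the bit
flags prior non-existence; callers strip it with POSITION_MASK and then call
rehashIfNeeded. The object name below is hypothetical, and since these classes are
package-private to Spark, the sketch assumes it compiles inside the
org.apache.spark.util.collection package.

  package org.apache.spark.util.collection

  object NonExistenceMaskSketch {
    def main(args: Array[String]) {
      val set = new OpenHashSet[Long](64)

      val pos1 = set.addWithoutResize(42L)
      // High bit set: 42L was not present before this call.
      assert((pos1 & OpenHashSet.NONEXISTENCE_MASK) != 0)
      // Strip the flag bit to recover the slot in the underlying array.
      println("inserted at slot " + (pos1 & OpenHashSet.POSITION_MASK))
      // Grow (and move entries) if the load factor is now exceeded; the
      // callbacks are no-ops here.
      set.rehashIfNeeded(42L, _ => (), (_, _) => ())

      val pos2 = set.addWithoutResize(42L)
      // High bit clear: the key already existed, so pos2 is the plain slot.
      assert((pos2 & OpenHashSet.NONEXISTENCE_MASK) == 0)
      assert(set.size == 1)
    }
  }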