diff --git a/core/pom.xml b/core/pom.xml index 459ef66712c36a289506a1aebec86327f3881035..2dfb00d7ecf260857df323f353416ce32bbb83d2 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -95,6 +95,11 @@ <artifactId>spark-network-shuffle_${scala.binary.version}</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-unsafe_${scala.binary.version}</artifactId> + <version>${project.version}</version> + </dependency> <dependency> <groupId>net.java.dev.jets3t</groupId> <artifactId>jets3t</artifactId> diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 959aefabd8de48ca79412ce7e7ee0cc5d574bb09..0c4d28f786edda1599b2d32e68fbab10ba0b28be 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -40,6 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator} import org.apache.spark.util.{RpcUtils, Utils} /** @@ -69,6 +70,7 @@ class SparkEnv ( val sparkFilesDir: String, val metricsSystem: MetricsSystem, val shuffleMemoryManager: ShuffleMemoryManager, + val executorMemoryManager: ExecutorMemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { @@ -382,6 +384,15 @@ object SparkEnv extends Logging { new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) + val executorMemoryManager: ExecutorMemoryManager = { + val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) { + MemoryAllocator.UNSAFE + } else { + MemoryAllocator.HEAP + } + new ExecutorMemoryManager(allocator) + } + val envInstance = new SparkEnv( executorId, rpcEnv, @@ -398,6 +409,7 @@ object SparkEnv extends Logging { sparkFilesDir, metricsSystem, shuffleMemoryManager, + executorMemoryManager, outputCommitCoordinator, conf) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 7d7fe1a44631360dc95eead4528e4bc233de4196..d09e17dea0911a3f4e32ac1f86f1bd5b704f5112 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,6 +21,7 @@ import java.io.Serializable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.TaskCompletionListener @@ -133,4 +134,9 @@ abstract class TaskContext extends Serializable { /** ::DeveloperApi:: */ @DeveloperApi def taskMetrics(): TaskMetrics + + /** + * Returns the manager for this task's managed memory. + */ + private[spark] def taskMemoryManager(): TaskMemoryManager } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 337c8e4ebebcdf895013560c31fe4c606929dee6..b4d572cb5231353c3a7f12ac3636802f02c4fbc9 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.apache.spark.executor.TaskMetrics +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} import scala.collection.mutable.ArrayBuffer @@ -27,6 +28,7 @@ private[spark] class TaskContextImpl( val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, + override val taskMemoryManager: TaskMemoryManager, val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index f57e215c3f2ed433c83804c37301cafc3dcda01e..dd1c48e6cb953b692ddcedf83334a84283a45e49 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -32,6 +32,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ /** @@ -178,6 +179,7 @@ private[spark] class Executor( } override def run(): Unit = { + val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = env.closureSerializer.newInstance() @@ -190,6 +192,7 @@ private[spark] class Executor( val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) updateDependencies(taskFiles, taskJars) task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + task.setTaskMemoryManager(taskMemoryManager) // If this task has been killed before we deserialized it, let's quit now. Otherwise, // continue executing the task. @@ -206,7 +209,21 @@ private[spark] class Executor( // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() - val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + val value = try { + task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + } finally { + // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread; + // when changing this, make sure to update both copies. + val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + if (freedMemory > 0) { + val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" + if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { + throw new SparkException(errMsg) + } else { + logError(errMsg) + } + } + } val taskFinish = System.currentTimeMillis() // If the task has been killed, let's fail it. diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8c4bff4e83afc6b1ef1b280e51b614b84ba9240f..b7901c06a16e0d273e581a2d102c0f35f5401710 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -34,6 +34,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -643,8 +644,15 @@ class DAGScheduler( try { val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) - val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0, - attemptNumber = 0, runningLocally = true) + val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) + val taskContext = + new TaskContextImpl( + job.finalStage.id, + job.partitions(0), + taskAttemptId = 0, + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + runningLocally = true) TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) @@ -652,6 +660,16 @@ class DAGScheduler( } finally { taskContext.markTaskCompleted() TaskContext.unset() + // Note: this memory freeing logic is duplicated in Executor.run(); when changing this, + // make sure to update both copies. + val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() + if (freedMemory > 0) { + if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { + throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes") + } else { + logError(s"Managed memory leak detected; size = $freedMemory bytes") + } + } } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index b09b19e2ac9e7061edfccb6742033d3be98cd038..586d1e06204c1cc0dd15f7b6ae9561937af1d669 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils @@ -52,8 +53,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex * @return the result of the task */ final def run(taskAttemptId: Long, attemptNumber: Int): T = { - context = new TaskContextImpl(stageId = stageId, partitionId = partitionId, - taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false) + context = new TaskContextImpl( + stageId = stageId, + partitionId = partitionId, + taskAttemptId = taskAttemptId, + attemptNumber = attemptNumber, + taskMemoryManager = taskMemoryManager, + runningLocally = false) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) taskThread = Thread.currentThread() @@ -68,6 +74,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex } } + private var taskMemoryManager: TaskMemoryManager = _ + + def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = { + this.taskMemoryManager = taskMemoryManager + } + def runTask(context: TaskContext): T def preferredLocations: Seq[TaskLocation] = Nil diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 8a4f2a08fe701cdd8ba5f3d4e3f62713a5491544..34ac9361d46c6942bd9567e59fe223afb5efe107 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1009,7 +1009,7 @@ public class JavaAPISuite implements Serializable { @Test public void iterator() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 70529d9216591ec7c896e962d2bb1fd61c05c775..668ddf9f5f0a90ad26560b4b1283e077f68ff8b3 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -65,7 +65,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -77,7 +77,7 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,14 +86,14 @@ class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAf // Local computation should not persist the resulting value, so don't expect a put(). when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - val context = new TaskContextImpl(0, 0, 0, 0, true) + val context = new TaskContextImpl(0, 0, 0, 0, null, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index aea76c1adcc099327b8304803ead5ec0d004e0ca..85eb2a1d07ba451ccb01dc384073b0c655031cc1 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -176,7 +176,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0, 0) + val tContext = new TaskContextImpl(0, 0, 0, 0, null) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 057e226916027c1dde94cb4fbd1718526e73f93e..83ae8701243e5cdd49e13c2d238349f6cfeaac19 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0, 0) + val context = new TaskContextImpl(0, 0, 0, 0, null) val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 37b593b2c5f7969aab1b21a000944c46d432f5d7..2080c432d77db7fc5d9322351bef8f16571a0cdf 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -89,7 +89,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0, 0), + new TaskContextImpl(0, 0, 0, 0, null), transfer, blockManager, blocksByAddress, @@ -154,7 +154,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -217,7 +217,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, diff --git a/pom.xml b/pom.xml index 928f5d0f5efad8d3a37fbede6395b173f474fdba..c85c5feeaf383d9a35f847ce756893d55800e1c8 100644 --- a/pom.xml +++ b/pom.xml @@ -97,6 +97,7 @@ <module>sql/catalyst</module> <module>sql/core</module> <module>sql/hive</module> + <module>unsafe</module> <module>assembly</module> <module>external/twitter</module> <module>external/flume</module> @@ -1215,6 +1216,7 @@ <spark.ui.enabled>false</spark.ui.enabled> <spark.ui.showConsoleProgress>false</spark.ui.showConsoleProgress> <spark.driver.allowMultipleContexts>true</spark.driver.allowMultipleContexts> + <spark.unsafe.exceptionOnMemoryLeak>true</spark.unsafe.exceptionOnMemoryLeak> </systemProperties> <failIfNoTests>false</failIfNoTests> </configuration> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 09b4976d10c26e636a0a6bfef142804c76fe6c7b..b7dbcd9bc562aee23c93037f7e0aacd6f270f9a7 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -34,11 +34,11 @@ object BuildCommons { val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, - streamingMqtt, streamingTwitter, streamingZeromq, launcher) = + streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", - "streaming-zeromq", "launcher").map(ProjectRef(buildLocation, _)) + "streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", @@ -159,7 +159,7 @@ object SparkBuild extends PomBuild { // TODO: Add Sql to mima checks // TODO: remove launcher from this list after 1.3. allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - networkCommon, networkShuffle, networkYarn, launcher).contains(x)).foreach { + networkCommon, networkShuffle, networkYarn, launcher, unsafe).contains(x)).foreach { x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) } @@ -496,6 +496,7 @@ object TestSettings { javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", + javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 3dea2ee76542f4442e22975bc33175170d4f96eb..5c322d032d47471a021603d1f028a3f5caffe75b 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -50,6 +50,11 @@ <artifactId>spark-core_${scala.binary.version}</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-unsafe_${scala.binary.version}</artifactId> + <version>${project.version}</version> + </dependency> <dependency> <groupId>org.scalacheck</groupId> <artifactId>scalacheck_${scala.binary.version}</artifactId> diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java new file mode 100644 index 0000000000000000000000000000000000000000..299ff3728a6d99450bf4401b12f26d3e7a7e5953 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import java.util.Arrays; +import java.util.Iterator; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.map.BytesToBytesMap; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +/** + * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. + * + * This map supports a maximum of 2 billion keys. + */ +public final class UnsafeFixedWidthAggregationMap { + + /** + * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the + * map, we copy this buffer and use it as the value. + */ + private final long[] emptyAggregationBuffer; + + private final StructType aggregationBufferSchema; + + private final StructType groupingKeySchema; + + /** + * Encodes grouping keys as UnsafeRows. + */ + private final UnsafeRowConverter groupingKeyToUnsafeRowConverter; + + /** + * A hashmap which maps from opaque bytearray keys to bytearray values. + */ + private final BytesToBytesMap map; + + /** + * Re-used pointer to the current aggregation buffer + */ + private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + + /** + * Scratch space that is used when encoding grouping keys into UnsafeRow format. + * + * By default, this is a 1MB array, but it will grow as necessary in case larger keys are + * encountered. + */ + private long[] groupingKeyConversionScratchSpace = new long[1024 / 8]; + + private final boolean enablePerfMetrics; + + /** + * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, + * false otherwise. + */ + public static boolean supportsGroupKeySchema(StructType schema) { + for (StructField field: schema.fields()) { + if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { + return false; + } + } + return true; + } + + /** + * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given + * schema, false otherwise. + */ + public static boolean supportsAggregationBufferSchema(StructType schema) { + for (StructField field: schema.fields()) { + if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { + return false; + } + } + return true; + } + + /** + * Create a new UnsafeFixedWidthAggregationMap. + * + * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) + * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. + * @param groupingKeySchema the schema of the grouping key, used for row conversion. + * @param memoryManager the memory manager used to allocate our Unsafe memory structures. + * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). + * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) + */ + public UnsafeFixedWidthAggregationMap( + Row emptyAggregationBuffer, + StructType aggregationBufferSchema, + StructType groupingKeySchema, + TaskMemoryManager memoryManager, + int initialCapacity, + boolean enablePerfMetrics) { + this.emptyAggregationBuffer = + convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema); + this.aggregationBufferSchema = aggregationBufferSchema; + this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema); + this.groupingKeySchema = groupingKeySchema; + this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.enablePerfMetrics = enablePerfMetrics; + } + + /** + * Convert a Java object row into an UnsafeRow, allocating it into a new long array. + */ + private static long[] convertToUnsafeRow(Row javaRow, StructType schema) { + final UnsafeRowConverter converter = new UnsafeRowConverter(schema); + final long[] unsafeRow = new long[converter.getSizeRequirement(javaRow)]; + final long writtenLength = + converter.writeRow(javaRow, unsafeRow, PlatformDependent.LONG_ARRAY_OFFSET); + assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!"; + return unsafeRow; + } + + /** + * Return the aggregation buffer for the current group. For efficiency, all calls to this method + * return the same object. + */ + public UnsafeRow getAggregationBuffer(Row groupingKey) { + final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey); + // Make sure that the buffer is large enough to hold the key. If it's not, grow it: + if (groupingKeySize > groupingKeyConversionScratchSpace.length) { + // This new array will be initially zero, so there's no need to zero it out here + groupingKeyConversionScratchSpace = new long[groupingKeySize]; + } else { + // Zero out the buffer that's used to hold the current row. This is necessary in order + // to ensure that rows hash properly, since garbage data from the previous row could + // otherwise end up as padding in this row. As a performance optimization, we only zero out + // the portion of the buffer that we'll actually write to. + Arrays.fill(groupingKeyConversionScratchSpace, 0, groupingKeySize, 0); + } + final long actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow( + groupingKey, + groupingKeyConversionScratchSpace, + PlatformDependent.LONG_ARRAY_OFFSET); + assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; + + // Probe our map using the serialized key + final BytesToBytesMap.Location loc = map.lookup( + groupingKeyConversionScratchSpace, + PlatformDependent.LONG_ARRAY_OFFSET, + groupingKeySize); + if (!loc.isDefined()) { + // This is the first time that we've seen this grouping key, so we'll insert a copy of the + // empty aggregation buffer into the map: + loc.putNewKey( + groupingKeyConversionScratchSpace, + PlatformDependent.LONG_ARRAY_OFFSET, + groupingKeySize, + emptyAggregationBuffer, + PlatformDependent.LONG_ARRAY_OFFSET, + emptyAggregationBuffer.length + ); + } + + // Reset the pointer to point to the value that we just stored or looked up: + final MemoryLocation address = loc.getValueAddress(); + currentAggregationBuffer.pointTo( + address.getBaseObject(), + address.getBaseOffset(), + aggregationBufferSchema.length(), + aggregationBufferSchema + ); + return currentAggregationBuffer; + } + + /** + * Mutable pair object returned by {@link UnsafeFixedWidthAggregationMap#iterator()}. + */ + public static class MapEntry { + private MapEntry() { }; + public final UnsafeRow key = new UnsafeRow(); + public final UnsafeRow value = new UnsafeRow(); + } + + /** + * Returns an iterator over the keys and values in this map. + * + * For efficiency, each call returns the same object. + */ + public Iterator<MapEntry> iterator() { + return new Iterator<MapEntry>() { + + private final MapEntry entry = new MapEntry(); + private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator(); + + @Override + public boolean hasNext() { + return mapLocationIterator.hasNext(); + } + + @Override + public MapEntry next() { + final BytesToBytesMap.Location loc = mapLocationIterator.next(); + final MemoryLocation keyAddress = loc.getKeyAddress(); + final MemoryLocation valueAddress = loc.getValueAddress(); + entry.key.pointTo( + keyAddress.getBaseObject(), + keyAddress.getBaseOffset(), + groupingKeySchema.length(), + groupingKeySchema + ); + entry.value.pointTo( + valueAddress.getBaseObject(), + valueAddress.getBaseOffset(), + aggregationBufferSchema.length(), + aggregationBufferSchema + ); + return entry; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + /** + * Free the unsafe memory associated with this map. + */ + public void free() { + map.free(); + } + + @SuppressWarnings("UseOfSystemOutOrSystemErr") + public void printPerfMetrics() { + if (!enablePerfMetrics) { + throw new IllegalStateException("Perf metrics not enabled"); + } + System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup()); + System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); + System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); + System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); + } + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java new file mode 100644 index 0000000000000000000000000000000000000000..0a358ed408aa1f12f304a7e8fbb8ba22022cd1c1 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import scala.collection.Map; +import scala.collection.Seq; +import scala.collection.mutable.ArraySeq; + +import javax.annotation.Nullable; +import java.math.BigDecimal; +import java.sql.Date; +import java.util.*; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.DataType; +import static org.apache.spark.sql.types.DataTypes.*; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.UTF8String; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.bitset.BitSetMethods; + +/** + * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. + * + * Each tuple has three parts: [null bit set] [values] [variable length portion] + * + * The bit set is used for null tracking and is aligned to 8-byte word boundaries. It stores + * one bit per field. + * + * In the `values` region, we store one 8-byte word per field. For fields that hold fixed-length + * primitive types, such as long, double, or int, we store the value directly in the word. For + * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the + * base address of the row) that points to the beginning of the variable-length field. + * + * Instances of `UnsafeRow` act as pointers to row data stored in this format. + */ +public final class UnsafeRow implements MutableRow { + + private Object baseObject; + private long baseOffset; + + Object getBaseObject() { return baseObject; } + long getBaseOffset() { return baseOffset; } + + /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ + private int numFields; + + /** The width of the null tracking bit set, in bytes */ + private int bitSetWidthInBytes; + /** + * This optional schema is required if you want to call generic get() and set() methods on + * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE() + * methods. This should be removed after the planned InternalRow / Row split; right now, it's only + * needed by the generic get() method, which is only called internally by code that accesses + * UTF8String-typed columns. + */ + @Nullable + private StructType schema; + + private long getFieldOffset(int ordinal) { + return baseOffset + bitSetWidthInBytes + ordinal * 8L; + } + + public static int calculateBitSetWidthInBytes(int numFields) { + return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; + } + + /** + * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) + */ + public static final Set<DataType> settableFieldTypes; + + /** + * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException). + */ + public static final Set<DataType> readableFieldTypes; + + static { + settableFieldTypes = Collections.unmodifiableSet( + new HashSet<DataType>( + Arrays.asList(new DataType[] { + NullType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType + }))); + + // We support get() on a superset of the types for which we support set(): + final Set<DataType> _readableFieldTypes = new HashSet<DataType>( + Arrays.asList(new DataType[]{ + StringType + })); + _readableFieldTypes.addAll(settableFieldTypes); + readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); + } + + /** + * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, + * since the value returned by this constructor is equivalent to a null pointer. + */ + public UnsafeRow() { } + + /** + * Update this UnsafeRow to point to different backing data. + * + * @param baseObject the base object + * @param baseOffset the offset within the base object + * @param numFields the number of fields in this row + * @param schema an optional schema; this is necessary if you want to call generic get() or set() + * methods on this row, but is optional if the caller will only use type-specific + * getTYPE() and setTYPE() methods. + */ + public void pointTo( + Object baseObject, + long baseOffset, + int numFields, + @Nullable StructType schema) { + assert numFields >= 0 : "numFields should >= 0"; + assert schema == null || schema.fields().length == numFields; + this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); + this.baseObject = baseObject; + this.baseOffset = baseOffset; + this.numFields = numFields; + this.schema = schema; + } + + private void assertIndexIsValid(int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < numFields : "index (" + index + ") should <= " + numFields; + } + + @Override + public void setNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.set(baseObject, baseOffset, i); + // To preserve row equality, zero out the value when setting the column to null. + // Since this row does does not currently support updates to variable-length values, we don't + // have to worry about zeroing out that data. + PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0); + } + + private void setNotNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.unset(baseObject, baseOffset, i); + } + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setInt(int ordinal, int value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putInt(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setLong(int ordinal, long value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setDouble(int ordinal, double value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putBoolean(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setShort(int ordinal, short value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putShort(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setByte(int ordinal, byte value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putByte(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setFloat(int ordinal, float value) { + assertIndexIsValid(ordinal); + setNotNullAt(ordinal); + PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); + } + + @Override + public void setString(int ordinal, String value) { + throw new UnsupportedOperationException(); + } + + @Override + public int size() { + return numFields; + } + + @Override + public int length() { + return size(); + } + + @Override + public StructType schema() { + return schema; + } + + @Override + public Object apply(int i) { + return get(i); + } + + @Override + public Object get(int i) { + assertIndexIsValid(i); + assert (schema != null) : "Schema must be defined when calling generic get() method"; + final DataType dataType = schema.fields()[i].dataType(); + // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic + // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to + // separate the internal and external row interfaces, then internal code can fetch strings via + // a new getUTF8String() method and we'll be able to remove this method. + if (isNullAt(i)) { + return null; + } else if (dataType == StringType) { + return getUTF8String(i); + } else { + throw new UnsupportedOperationException(); + } + } + + @Override + public boolean isNullAt(int i) { + assertIndexIsValid(i); + return BitSetMethods.isSet(baseObject, baseOffset, i); + } + + @Override + public boolean getBoolean(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(i)); + } + + @Override + public byte getByte(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(i)); + } + + @Override + public short getShort(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(i)); + } + + @Override + public int getInt(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(i)); + } + + @Override + public long getLong(int i) { + assertIndexIsValid(i); + return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); + } + + @Override + public float getFloat(int i) { + assertIndexIsValid(i); + if (isNullAt(i)) { + return Float.NaN; + } else { + return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i)); + } + } + + @Override + public double getDouble(int i) { + assertIndexIsValid(i); + if (isNullAt(i)) { + return Float.NaN; + } else { + return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i)); + } + } + + public UTF8String getUTF8String(int i) { + assertIndexIsValid(i); + final UTF8String str = new UTF8String(); + final long offsetToStringSize = getLong(i); + final int stringSizeInBytes = + (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize); + final byte[] strBytes = new byte[stringSizeInBytes]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + offsetToStringSize + 8, // The `+ 8` is to skip past the size to get the data + strBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + stringSizeInBytes + ); + str.set(strBytes); + return str; + } + + @Override + public String getString(int i) { + return getUTF8String(i).toString(); + } + + @Override + public BigDecimal getDecimal(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Date getDate(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> Seq<T> getSeq(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> List<T> getList(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <K, V> Map<K, V> getMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> fieldNames) { + throw new UnsupportedOperationException(); + } + + @Override + public <K, V> java.util.Map<K, V> getJavaMap(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public Row getStruct(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> T getAs(int i) { + throw new UnsupportedOperationException(); + } + + @Override + public <T> T getAs(String fieldName) { + throw new UnsupportedOperationException(); + } + + @Override + public int fieldIndex(String name) { + throw new UnsupportedOperationException(); + } + + @Override + public Row copy() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean anyNull() { + return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); + } + + @Override + public Seq<Object> toSeq() { + final ArraySeq<Object> values = new ArraySeq<Object>(numFields); + for (int fieldNumber = 0; fieldNumber < numFields; fieldNumber++) { + values.update(fieldNumber, get(fieldNumber)); + } + return values; + } + + @Override + public String toString() { + return mkString("[", ",", "]"); + } + + @Override + public String mkString() { + return toSeq().mkString(); + } + + @Override + public String mkString(String sep) { + return toSeq().mkString(sep); + } + + @Override + public String mkString(String start, String sep, String end) { + return toSeq().mkString(start, sep, end); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala new file mode 100644 index 0000000000000000000000000000000000000000..5b2c8572784bdf7c36cb6c90176e9144df02cb26 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.array.ByteArrayMethods + +/** + * Converts Rows into UnsafeRow format. This class is NOT thread-safe. + * + * @param fieldTypes the data types of the row's columns. + */ +class UnsafeRowConverter(fieldTypes: Array[DataType]) { + + def this(schema: StructType) { + this(schema.fields.map(_.dataType)) + } + + /** Re-used pointer to the unsafe row being written */ + private[this] val unsafeRow = new UnsafeRow() + + /** Functions for encoding each column */ + private[this] val writers: Array[UnsafeColumnWriter] = { + fieldTypes.map(t => UnsafeColumnWriter.forType(t)) + } + + /** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */ + private[this] val fixedLengthSize: Int = + (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) + + /** + * Compute the amount of space, in bytes, required to encode the given row. + */ + def getSizeRequirement(row: Row): Int = { + var fieldNumber = 0 + var variableLengthFieldSize: Int = 0 + while (fieldNumber < writers.length) { + if (!row.isNullAt(fieldNumber)) { + variableLengthFieldSize += writers(fieldNumber).getSize(row, fieldNumber) + } + fieldNumber += 1 + } + fixedLengthSize + variableLengthFieldSize + } + + /** + * Convert the given row into UnsafeRow format. + * + * @param row the row to convert + * @param baseObject the base object of the destination address + * @param baseOffset the base offset of the destination address + * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. + */ + def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) + var fieldNumber = 0 + var appendCursor: Int = fixedLengthSize + while (fieldNumber < writers.length) { + if (row.isNullAt(fieldNumber)) { + unsafeRow.setNullAt(fieldNumber) + } else { + appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor) + } + fieldNumber += 1 + } + appendCursor + } + +} + +/** + * Function for writing a column into an UnsafeRow. + */ +private abstract class UnsafeColumnWriter { + /** + * Write a value into an UnsafeRow. + * + * @param source the row being converted + * @param target a pointer to the converted unsafe row + * @param column the column to write + * @param appendCursor the offset from the start of the unsafe row to the end of the row; + * used for calculating where variable-length data should be written + * @return the number of variable-length bytes written + */ + def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int + + /** + * Return the number of bytes that are needed to write this variable-length value. + */ + def getSize(source: Row, column: Int): Int +} + +private object UnsafeColumnWriter { + + def forType(dataType: DataType): UnsafeColumnWriter = { + dataType match { + case NullType => NullUnsafeColumnWriter + case BooleanType => BooleanUnsafeColumnWriter + case ByteType => ByteUnsafeColumnWriter + case ShortType => ShortUnsafeColumnWriter + case IntegerType => IntUnsafeColumnWriter + case LongType => LongUnsafeColumnWriter + case FloatType => FloatUnsafeColumnWriter + case DoubleType => DoubleUnsafeColumnWriter + case StringType => StringUnsafeColumnWriter + case t => + throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") + } + } +} + +// ------------------------------------------------------------------------------------------------ + +private object NullUnsafeColumnWriter extends NullUnsafeColumnWriter +private object BooleanUnsafeColumnWriter extends BooleanUnsafeColumnWriter +private object ByteUnsafeColumnWriter extends ByteUnsafeColumnWriter +private object ShortUnsafeColumnWriter extends ShortUnsafeColumnWriter +private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter +private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter +private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter +private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter +private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter + +private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { + // Primitives don't write to the variable-length region: + def getSize(sourceRow: Row, column: Int): Int = 0 +} + +private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setNullAt(column) + 0 + } +} + +private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setBoolean(column, source.getBoolean(column)) + 0 + } +} + +private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setByte(column, source.getByte(column)) + 0 + } +} + +private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setShort(column, source.getShort(column)) + 0 + } +} + +private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setInt(column, source.getInt(column)) + 0 + } +} + +private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setLong(column, source.getLong(column)) + 0 + } +} + +private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setFloat(column, source.getFloat(column)) + 0 + } +} + +private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + target.setDouble(column, source.getDouble(column)) + 0 + } +} + +private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter { + def getSize(source: Row, column: Int): Int = { + val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length + 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + } + + override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = { + val value = source.get(column).asInstanceOf[UTF8String] + val baseObject = target.getBaseObject + val baseOffset = target.getBaseOffset + val numBytes = value.getBytes.length + PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes) + PlatformDependent.copyMemory( + value.getBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + baseOffset + appendCursor + 8, + numBytes + ) + target.setLong(column, appendCursor) + 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..7a19e511eb8b566ca1981d4f0be78d2db735a190 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.JavaConverters._ +import scala.util.Random + +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} +import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers} + +import org.apache.spark.sql.types._ + +class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with BeforeAndAfterEach { + + import UnsafeFixedWidthAggregationMap._ + + private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) + private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) + private def emptyAggregationBuffer: Row = new GenericRow(Array[Any](0)) + + private var memoryManager: TaskMemoryManager = null + + override def beforeEach(): Unit = { + memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + } + + override def afterEach(): Unit = { + if (memoryManager != null) { + memoryManager.cleanUpAllAllocatedMemory() + memoryManager = null + } + } + + test("supported schemas") { + assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) + assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) + + assert( + !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + assert( + !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + } + + test("empty map") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + memoryManager, + 1024, // initial capacity + false // disable perf metrics + ) + assert(!map.iterator().hasNext) + map.free() + } + + test("updating values for a single key") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + memoryManager, + 1024, // initial capacity + false // disable perf metrics + ) + val groupKey = new GenericRow(Array[Any](UTF8String("cats"))) + + // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) + map.getAggregationBuffer(groupKey) + val iter = map.iterator() + val entry = iter.next() + assert(!iter.hasNext) + entry.key.getString(0) should be ("cats") + entry.value.getInt(0) should be (0) + + // Modifications to rows retrieved from the map should update the values in the map + entry.value.setInt(0, 42) + map.getAggregationBuffer(groupKey).getInt(0) should be (42) + + map.free() + } + + test("inserting large random keys") { + val map = new UnsafeFixedWidthAggregationMap( + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, + memoryManager, + 128, // initial capacity + false // disable perf metrics + ) + val rand = new Random(42) + val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet + groupKeys.foreach { keyString => + map.getAggregationBuffer(new GenericRow(Array[Any](UTF8String(keyString)))) + } + val seenKeys: Set[String] = map.iterator().asScala.map { entry => + entry.key.getString(0) + }.toSet + seenKeys.size should be (groupKeys.size) + seenKeys should be (groupKeys) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..3a60c7fd32675988b58ae4e54ffd83ca5066887d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.Arrays + +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.array.ByteArrayMethods + +class UnsafeRowConverterSuite extends FunSuite with Matchers { + + test("basic conversion with only primitive types") { + val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.setLong(1, 1) + row.setInt(2, 2) + + val sizeRequired: Int = converter.getSizeRequirement(row) + sizeRequired should be (8 + (3 * 8)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) + numBytesWritten should be (sizeRequired) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.getLong(0) should be (0) + unsafeRow.getLong(1) should be (1) + unsafeRow.getInt(2) should be (2) + } + + test("basic conversion with primitive and string types") { + val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.setString(1, "Hello") + row.setString(2, "World") + + val sizeRequired: Int = converter.getSizeRequirement(row) + sizeRequired should be (8 + (8 * 3) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) + + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) + numBytesWritten should be (sizeRequired) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + unsafeRow.getLong(0) should be (0) + unsafeRow.getString(1) should be ("Hello") + unsafeRow.getString(2) should be ("World") + } + + test("null handling") { + val fieldTypes: Array[DataType] = Array( + NullType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType) + val converter = new UnsafeRowConverter(fieldTypes) + + val rowWithAllNullColumns: Row = { + val r = new SpecificMutableRow(fieldTypes) + for (i <- 0 to fieldTypes.length - 1) { + r.setNullAt(i) + } + r + } + + val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns) + val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow( + rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + numBytesWritten should be (sizeRequired) + + val createdFromNull = new UnsafeRow() + createdFromNull.pointTo( + createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + for (i <- 0 to fieldTypes.length - 1) { + assert(createdFromNull.isNullAt(i)) + } + createdFromNull.getBoolean(1) should be (false) + createdFromNull.getByte(2) should be (0) + createdFromNull.getShort(3) should be (0) + createdFromNull.getInt(4) should be (0) + createdFromNull.getLong(5) should be (0) + assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) + assert(java.lang.Double.isNaN(createdFromNull.getFloat(7))) + + // If we have an UnsafeRow with columns that are initially non-null and we null out those + // columns, then the serialized row representation should be identical to what we would get by + // creating an entirely null row via the converter + val rowWithNoNullColumns: Row = { + val r = new SpecificMutableRow(fieldTypes) + r.setNullAt(0) + r.setBoolean(1, false) + r.setByte(2, 20) + r.setShort(3, 30) + r.setInt(4, 400) + r.setLong(5, 500) + r.setFloat(6, 600) + r.setDouble(7, 700) + r + } + val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8) + converter.writeRow( + rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + val setToNullAfterCreation = new UnsafeRow() + setToNullAfterCreation.pointTo( + setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + + setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0)) + setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1)) + setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2)) + setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3)) + setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4)) + setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5)) + setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6)) + setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7)) + + for (i <- 0 to fieldTypes.length - 1) { + setToNullAfterCreation.setNullAt(i) + } + assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer)) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 4fc5de7e824fe3712a502a42a6cb9c08f81be8a0..2fa602a6082dca9431846baf1e8a730ffc15f3b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -30,6 +30,7 @@ private[spark] object SQLConf { val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" val CODEGEN_ENABLED = "spark.sql.codegen" + val UNSAFE_ENABLED = "spark.sql.unsafe.enabled" val DIALECT = "spark.sql.dialect" val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString" @@ -149,6 +150,14 @@ private[sql] class SQLConf extends Serializable { */ private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean + /** + * When set to true, Spark SQL will use managed memory for certain operations. This option only + * takes effect if codegen is enabled. + * + * Defaults to false as this feature is currently experimental. + */ + private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, "false").toBoolean + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a279b0f07c38a0f3d9f89a09cf00bcfd33e5984f..bd4a55fa132fb2d5dbc4845fc69152b402a28a7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1011,6 +1011,8 @@ class SQLContext(@transient val sparkContext: SparkContext) def codegenEnabled: Boolean = self.conf.codegenEnabled + def unsafeEnabled: Boolean = self.conf.unsafeEnabled + def numPartitions: Int = self.conf.numShufflePartitions def strategies: Seq[Strategy] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index b1ef6556de1e9e3c38a5dd2cc8aae53333ac7f78..5d9f202681045b6b142ba16d467e1dff0d2baa70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.trees._ @@ -40,6 +41,7 @@ case class AggregateEvaluation( * ensure all values where `groupingExpressions` are equal are present. * @param groupingExpressions expressions that are evaluated to determine grouping. * @param aggregateExpressions expressions that are computed for each group. + * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used. * @param child the input data source. */ @DeveloperApi @@ -47,6 +49,7 @@ case class GeneratedAggregate( partial: Boolean, groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], + unsafeEnabled: Boolean, child: SparkPlan) extends UnaryNode { @@ -225,6 +228,21 @@ case class GeneratedAggregate( case e: Expression if groupMap.contains(e) => groupMap(e) }) + val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema) + + val groupKeySchema: StructType = { + val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) => + // This is a dummy field name + StructField(idx.toString, expr.dataType, expr.nullable) + } + StructType(fields) + } + + val schemaSupportsUnsafe: Boolean = { + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) + } + child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. val initialValues = computeFunctions.flatMap(_.initialValues) @@ -265,7 +283,49 @@ case class GeneratedAggregate( val resultProjection = resultProjectionBuilder() Iterator(resultProjection(buffer)) + } else if (unsafeEnabled && schemaSupportsUnsafe) { + log.info("Using Unsafe-based aggregator") + val aggregationMap = new UnsafeFixedWidthAggregationMap( + newAggregationBuffer(EmptyRow), + aggregationBufferSchema, + groupKeySchema, + TaskContext.get.taskMemoryManager(), + 1024 * 16, // initial capacity + false // disable tracking of performance metrics + ) + + while (iter.hasNext) { + val currentRow: Row = iter.next() + val groupKey: Row = groupProjection(currentRow) + val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey) + updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow)) + } + + new Iterator[Row] { + private[this] val mapIterator = aggregationMap.iterator() + private[this] val resultProjection = resultProjectionBuilder() + + def hasNext: Boolean = mapIterator.hasNext + + def next(): Row = { + val entry = mapIterator.next() + val result = resultProjection(joinedRow(entry.key, entry.value)) + if (hasNext) { + result + } else { + // This is the last element in the iterator, so let's free the buffer. Before we do, + // though, we need to make a defensive copy of the result so that we don't return an + // object that might contain dangling pointers to the freed memory + val resultCopy = result.copy() + aggregationMap.free() + resultCopy + } + } + } } else { + if (unsafeEnabled) { + log.info("Not using Unsafe-based aggregator because it is not supported for this schema") + } val buffers = new java.util.HashMap[Row, MutableRow]() var currentRow: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3a0a6c86700a8bfb644ba9b95c92107e7c15e509..af58911cc0e6a55c89dc295c40419f63a703e640 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -136,10 +136,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { partial = false, namedGroupingAttributes, rewrittenAggregateExpressions, + unsafeEnabled, execution.GeneratedAggregate( partial = true, groupingExpressions, partialComputation, + unsafeEnabled, planLater(child))) :: Nil // Cases where some aggregate can not be codegened diff --git a/unsafe/pom.xml b/unsafe/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..8901d77591932292a7db3c4ae349bad9bb9a04b8 --- /dev/null +++ b/unsafe/pom.xml @@ -0,0 +1,69 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- + ~ Licensed to the Apache Software Foundation (ASF) under one or more + ~ contributor license agreements. See the NOTICE file distributed with + ~ this work for additional information regarding copyright ownership. + ~ The ASF licenses this file to You under the Apache License, Version 2.0 + ~ (the "License"); you may not use this file except in compliance with + ~ the License. You may obtain a copy of the License at + ~ + ~ http://www.apache.org/licenses/LICENSE-2.0 + ~ + ~ Unless required by applicable law or agreed to in writing, software + ~ distributed under the License is distributed on an "AS IS" BASIS, + ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + ~ See the License for the specific language governing permissions and + ~ limitations under the License. + --> + +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>org.apache.spark</groupId> + <artifactId>spark-parent_2.10</artifactId> + <version>1.4.0-SNAPSHOT</version> + <relativePath>../pom.xml</relativePath> + </parent> + + <groupId>org.apache.spark</groupId> + <artifactId>spark-unsafe_2.10</artifactId> + <packaging>jar</packaging> + <name>Spark Project Unsafe</name> + <url>http://spark.apache.org/</url> + <properties> + <sbt.project.name>unsafe</sbt.project.name> + </properties> + + <dependencies> + + <!-- Core dependencies --> + <dependency> + <groupId>com.google.code.findbugs</groupId> + <artifactId>jsr305</artifactId> + </dependency> + + <!-- Provided dependencies --> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-api</artifactId> + <scope>provided</scope> + </dependency> + + <!-- Test dependencies --> + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>com.novocode</groupId> + <artifactId>junit-interface</artifactId> + <scope>test</scope> + </dependency> + </dependencies> + <build> + <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> + <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory> + </build> +</project> diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java new file mode 100644 index 0000000000000000000000000000000000000000..91b2f9aa439211b842410a26ff49fdd139a51e35 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe; + +import java.lang.reflect.Field; + +import sun.misc.Unsafe; + +public final class PlatformDependent { + + public static final Unsafe UNSAFE; + + public static final int BYTE_ARRAY_OFFSET; + + public static final int INT_ARRAY_OFFSET; + + public static final int LONG_ARRAY_OFFSET; + + public static final int DOUBLE_ARRAY_OFFSET; + + /** + * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to + * allow safepoint polling during a large copy. + */ + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + static { + sun.misc.Unsafe unsafe; + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (sun.misc.Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + unsafe = null; + } + UNSAFE = unsafe; + + if (UNSAFE != null) { + BYTE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); + INT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(int[].class); + LONG_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(long[].class); + DOUBLE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(double[].class); + } else { + BYTE_ARRAY_OFFSET = 0; + INT_ARRAY_OFFSET = 0; + LONG_ARRAY_OFFSET = 0; + DOUBLE_ARRAY_OFFSET = 0; + } + } + + static public void copyMemory( + Object src, + long srcOffset, + Object dst, + long dstOffset, + long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } + + /** + * Raises an exception bypassing compiler checks for checked exceptions. + */ + public static void throwException(Throwable t) { + UNSAFE.throwException(t); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java new file mode 100644 index 0000000000000000000000000000000000000000..53eadf96a6b52f98d4aa10b3f5db9c560a890591 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.array; + +import org.apache.spark.unsafe.PlatformDependent; + +public class ByteArrayMethods { + + private ByteArrayMethods() { + // Private constructor, since this class only contains static methods. + } + + public static int roundNumberOfBytesToNearestWord(int numBytes) { + int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8` + if (remainder == 0) { + return numBytes; + } else { + return numBytes + (8 - remainder); + } + } + + /** + * Optimized byte array equality check for 8-byte-word-aligned byte arrays. + * @return true if the arrays are equal, false otherwise + */ + public static boolean wordAlignedArrayEquals( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset, + long arrayLengthInBytes) { + for (int i = 0; i < arrayLengthInBytes; i += 8) { + final long left = + PlatformDependent.UNSAFE.getLong(leftBaseObject, leftBaseOffset + i); + final long right = + PlatformDependent.UNSAFE.getLong(rightBaseObject, rightBaseOffset + i); + if (left != right) return false; + } + return true; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java new file mode 100644 index 0000000000000000000000000000000000000000..18d1f0d2d7eb21381c1336f23363d24700baa1fe --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.array; + +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * An array of long values. Compared with native JVM arrays, this: + * <ul> + * <li>supports using both in-heap and off-heap memory</li> + * <li>has no bound checking, and thus can crash the JVM process when assert is turned off</li> + * </ul> + */ +public final class LongArray { + + // This is a long so that we perform long multiplications when computing offsets. + private static final long WIDTH = 8; + + private final MemoryBlock memory; + private final Object baseObj; + private final long baseOffset; + + private final long length; + + public LongArray(MemoryBlock memory) { + assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")"; + assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements"; + this.memory = memory; + this.baseObj = memory.getBaseObject(); + this.baseOffset = memory.getBaseOffset(); + this.length = memory.size() / WIDTH; + } + + public MemoryBlock memoryBlock() { + return memory; + } + + /** + * Returns the number of elements this array can hold. + */ + public long size() { + return length; + } + + /** + * Sets the value at position {@code index}. + */ + public void set(int index, long value) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < length : "index (" + index + ") should < length (" + length + ")"; + PlatformDependent.UNSAFE.putLong(baseObj, baseOffset + index * WIDTH, value); + } + + /** + * Returns the value at position {@code index}. + */ + public long get(int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < length : "index (" + index + ") should < length (" + length + ")"; + return PlatformDependent.UNSAFE.getLong(baseObj, baseOffset + index * WIDTH); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java new file mode 100644 index 0000000000000000000000000000000000000000..f72e07fce92fdcfb70c7fbcd63a105941a3207fb --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * A fixed size uncompressed bit set backed by a {@link LongArray}. + * + * Each bit occupies exactly one bit of storage. + */ +public final class BitSet { + + /** A long array for the bits. */ + private final LongArray words; + + /** Length of the long array. */ + private final int numWords; + + private final Object baseObject; + private final long baseOffset; + + /** + * Creates a new {@link BitSet} using the specified memory block. Size of the memory block must be + * multiple of 8 bytes (i.e. 64 bits). + */ + public BitSet(MemoryBlock memory) { + words = new LongArray(memory); + assert (words.size() <= Integer.MAX_VALUE); + numWords = (int) words.size(); + baseObject = words.memoryBlock().getBaseObject(); + baseOffset = words.memoryBlock().getBaseOffset(); + } + + public MemoryBlock memoryBlock() { + return words.memoryBlock(); + } + + /** + * Returns the number of bits in this {@code BitSet}. + */ + public long capacity() { + return numWords * 64; + } + + /** + * Sets the bit at the specified index to {@code true}. + */ + public void set(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + BitSetMethods.set(baseObject, baseOffset, index); + } + + /** + * Sets the bit at the specified index to {@code false}. + */ + public void unset(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + BitSetMethods.unset(baseObject, baseOffset, index); + } + + /** + * Returns {@code true} if the bit is set at the specified index. + */ + public boolean isSet(int index) { + assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")"; + return BitSetMethods.isSet(baseObject, baseOffset, index); + } + + /** + * Returns the index of the first bit that is set to true that occurs on or after the + * specified starting index. If no such bit exists then {@code -1} is returned. + * <p> + * To iterate over the true bits in a BitSet, use the following loop: + * <pre> + * <code> + * for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) { + * // operate on index i here + * } + * </code> + * </pre> + * + * @param fromIndex the index to start checking from (inclusive) + * @return the index of the next set bit, or -1 if there is no such bit + */ + public int nextSetBit(int fromIndex) { + return BitSetMethods.nextSetBit(baseObject, baseOffset, fromIndex, numWords); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java new file mode 100644 index 0000000000000000000000000000000000000000..f30626d8f4317765a344aa5e36d3c8310795785d --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * Methods for working with fixed-size uncompressed bitsets. + * + * We assume that the bitset data is word-aligned (that is, a multiple of 8 bytes in length). + * + * Each bit occupies exactly one bit of storage. + */ +public final class BitSetMethods { + + private static final long WORD_SIZE = 8; + + private BitSetMethods() { + // Make the default constructor private, since this only holds static methods. + } + + /** + * Sets the bit at the specified index to {@code true}. + */ + public static void set(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word | mask); + } + + /** + * Sets the bit at the specified index to {@code false}. + */ + public static void unset(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word & ~mask); + } + + /** + * Returns {@code true} if the bit is set at the specified index. + */ + public static boolean isSet(Object baseObject, long baseOffset, int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + final long mask = 1L << (index & 0x3f); // mod 64 and shift + final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE; + final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset); + return (word & mask) != 0; + } + + /** + * Returns {@code true} if any bit is set. + */ + public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInBytes) { + for (int i = 0; i <= bitSetWidthInBytes; i++) { + if (PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + i) != 0) { + return true; + } + } + return false; + } + + /** + * Returns the index of the first bit that is set to true that occurs on or after the + * specified starting index. If no such bit exists then {@code -1} is returned. + * <p> + * To iterate over the true bits in a BitSet, use the following loop: + * <pre> + * <code> + * for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) { + * // operate on index i here + * } + * </code> + * </pre> + * + * @param fromIndex the index to start checking from (inclusive) + * @param bitsetSizeInWords the size of the bitset, measured in 8-byte words + * @return the index of the next set bit, or -1 if there is no such bit + */ + public static int nextSetBit( + Object baseObject, + long baseOffset, + int fromIndex, + int bitsetSizeInWords) { + int wi = fromIndex >> 6; + if (wi >= bitsetSizeInWords) { + return -1; + } + + // Try to find the next set bit in the current word + final int subIndex = fromIndex & 0x3f; + long word = + PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex; + if (word != 0) { + return (wi << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word); + } + + // Find the next set bit in the rest of the words + wi += 1; + while (wi < bitsetSizeInWords) { + word = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE); + if (word != 0) { + return (wi << 6) + java.lang.Long.numberOfTrailingZeros(word); + } + wi += 1; + } + + return -1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java new file mode 100644 index 0000000000000000000000000000000000000000..85cd02469adb76defa5b94922aa297413d318167 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.hash; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. + */ +public final class Murmur3_x86_32 { + private static final int C1 = 0xcc9e2d51; + private static final int C2 = 0x1b873593; + + private final int seed; + + public Murmur3_x86_32(int seed) { + this.seed = seed; + } + + @Override + public String toString() { + return "Murmur3_32(seed=" + seed + ")"; + } + + public int hashInt(int input) { + int k1 = mixK1(input); + int h1 = mixH1(seed, k1); + + return fmix(h1, 4); + } + + public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) { + // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. + assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; + int h1 = seed; + for (int offset = 0; offset < lengthInBytes; offset += 4) { + int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return fmix(h1, lengthInBytes); + } + + public int hashLong(long input) { + int low = (int) input; + int high = (int) (input >>> 32); + + int k1 = mixK1(low); + int h1 = mixH1(seed, k1); + + k1 = mixK1(high); + h1 = mixH1(h1, k1); + + return fmix(h1, 8); + } + + private static int mixK1(int k1) { + k1 *= C1; + k1 = Integer.rotateLeft(k1, 15); + k1 *= C2; + return k1; + } + + private static int mixH1(int h1, int k1) { + h1 ^= k1; + h1 = Integer.rotateLeft(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + return h1; + } + + // Finalization mix - force all bits of a hash block to avalanche + private static int fmix(int h1, int length) { + h1 ^= length; + h1 ^= h1 >>> 16; + h1 *= 0x85ebca6b; + h1 ^= h1 >>> 13; + h1 *= 0xc2b2ae35; + h1 ^= h1 >>> 16; + return h1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java new file mode 100644 index 0000000000000000000000000000000000000000..85b64c083380380d1b2b999f5d199ddc49dfd86d --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -0,0 +1,549 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import java.lang.Override; +import java.lang.UnsupportedOperationException; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.apache.spark.unsafe.*; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.bitset.BitSet; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.memory.*; + +/** + * An append-only hash map where keys and values are contiguous regions of bytes. + * <p> + * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers, + * which is guaranteed to exhaust the space. + * <p> + * The map can support up to 2^31 keys because we use 32 bit MurmurHash. If the key cardinality is + * higher than this, you should probably be using sorting instead of hashing for better cache + * locality. + * <p> + * This class is not thread safe. + */ +public final class BytesToBytesMap { + + private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0); + + private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; + + private final TaskMemoryManager memoryManager; + + /** + * A linked list for tracking all allocated data pages so that we can free all of our memory. + */ + private final List<MemoryBlock> dataPages = new LinkedList<MemoryBlock>(); + + /** + * The data page that will be used to store keys and values for new hashtable entries. When this + * page becomes full, a new page will be allocated and this pointer will change to point to that + * new page. + */ + private MemoryBlock currentDataPage = null; + + /** + * Offset into `currentDataPage` that points to the location where new data can be inserted into + * the page. + */ + private long pageCursor = 0; + + /** + * The size of the data pages that hold key and value data. Map entries cannot span multiple + * pages, so this limits the maximum entry size. + */ + private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + + // This choice of page table size and page size means that we can address up to 500 gigabytes + // of memory. + + /** + * A single array to store the key and value. + * + * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i}, + * while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode. + */ + private LongArray longArray; + // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode + // and exploit word-alignment to use fewer bits to hold the address. This might let us store + // only one long per map entry, increasing the chance that this array will fit in cache at the + // expense of maybe performing more lookups if we have hash collisions. Say that we stored only + // 27 bits of the hashcode and 37 bits of the address. 37 bits is enough to address 1 terabyte + // of RAM given word-alignment. If we use 13 bits of this for our page table, that gives us a + // maximum page size of 2^24 * 8 = ~134 megabytes per page. This change will require us to store + // full base addresses in the page table for off-heap mode so that we can reconstruct the full + // absolute memory addresses. + + /** + * A {@link BitSet} used to track location of the map where the key is set. + * Size of the bitset should be half of the size of the long array. + */ + private BitSet bitset; + + private final double loadFactor; + + /** + * Number of keys defined in the map. + */ + private int size; + + /** + * The map will be expanded once the number of keys exceeds this threshold. + */ + private int growthThreshold; + + /** + * Mask for truncating hashcodes so that they do not exceed the long array's size. + * This is a strength reduction optimization; we're essentially performing a modulus operation, + * but doing so with a bitmask because this is a power-of-2-sized hash map. + */ + private int mask; + + /** + * Return value of {@link BytesToBytesMap#lookup(Object, long, int)}. + */ + private final Location loc; + + private final boolean enablePerfMetrics; + + private long timeSpentResizingNs = 0; + + private long numProbes = 0; + + private long numKeyLookups = 0; + + private long numHashCollisions = 0; + + public BytesToBytesMap( + TaskMemoryManager memoryManager, + int initialCapacity, + double loadFactor, + boolean enablePerfMetrics) { + this.memoryManager = memoryManager; + this.loadFactor = loadFactor; + this.loc = new Location(); + this.enablePerfMetrics = enablePerfMetrics; + allocate(initialCapacity); + } + + public BytesToBytesMap(TaskMemoryManager memoryManager, int initialCapacity) { + this(memoryManager, initialCapacity, 0.70, false); + } + + public BytesToBytesMap( + TaskMemoryManager memoryManager, + int initialCapacity, + boolean enablePerfMetrics) { + this(memoryManager, initialCapacity, 0.70, enablePerfMetrics); + } + + /** + * Returns the number of keys defined in the map. + */ + public int size() { return size; } + + /** + * Returns an iterator for iterating over the entries of this map. + * + * For efficiency, all calls to `next()` will return the same {@link Location} object. + * + * If any other lookups or operations are performed on this map while iterating over it, including + * `lookup()`, the behavior of the returned iterator is undefined. + */ + public Iterator<Location> iterator() { + return new Iterator<Location>() { + + private int nextPos = bitset.nextSetBit(0); + + @Override + public boolean hasNext() { + return nextPos != -1; + } + + @Override + public Location next() { + final int pos = nextPos; + nextPos = bitset.nextSetBit(nextPos + 1); + return loc.with(pos, 0, true); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + /** + * Looks up a key, and return a {@link Location} handle that can be used to test existence + * and read/write values. + * + * This function always return the same {@link Location} instance to avoid object allocation. + */ + public Location lookup( + Object keyBaseObject, + long keyBaseOffset, + int keyRowLengthBytes) { + if (enablePerfMetrics) { + numKeyLookups++; + } + final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes); + int pos = hashcode & mask; + int step = 1; + while (true) { + if (enablePerfMetrics) { + numProbes++; + } + if (!bitset.isSet(pos)) { + // This is a new key. + return loc.with(pos, hashcode, false); + } else { + long stored = longArray.get(pos * 2 + 1); + if ((int) (stored) == hashcode) { + // Full hash code matches. Let's compare the keys for equality. + loc.with(pos, hashcode, true); + if (loc.getKeyLength() == keyRowLengthBytes) { + final MemoryLocation keyAddress = loc.getKeyAddress(); + final Object storedKeyBaseObject = keyAddress.getBaseObject(); + final long storedKeyBaseOffset = keyAddress.getBaseOffset(); + final boolean areEqual = ByteArrayMethods.wordAlignedArrayEquals( + keyBaseObject, + keyBaseOffset, + storedKeyBaseObject, + storedKeyBaseOffset, + keyRowLengthBytes + ); + if (areEqual) { + return loc; + } else { + if (enablePerfMetrics) { + numHashCollisions++; + } + } + } + } + } + pos = (pos + step) & mask; + step++; + } + } + + /** + * Handle returned by {@link BytesToBytesMap#lookup(Object, long, int)} function. + */ + public final class Location { + /** An index into the hash map's Long array */ + private int pos; + /** True if this location points to a position where a key is defined, false otherwise */ + private boolean isDefined; + /** + * The hashcode of the most recent key passed to + * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to + * avoid re-hashing the key when storing a value for that key. + */ + private int keyHashcode; + private final MemoryLocation keyMemoryLocation = new MemoryLocation(); + private final MemoryLocation valueMemoryLocation = new MemoryLocation(); + private int keyLength; + private int valueLength; + + private void updateAddressesAndSizes(long fullKeyAddress) { + final Object page = memoryManager.getPage(fullKeyAddress); + final long keyOffsetInPage = memoryManager.getOffsetInPage(fullKeyAddress); + long position = keyOffsetInPage; + keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position); + position += 8; // word used to store the key size + keyMemoryLocation.setObjAndOffset(page, position); + position += keyLength; + valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position); + position += 8; // word used to store the key size + valueMemoryLocation.setObjAndOffset(page, position); + } + + Location with(int pos, int keyHashcode, boolean isDefined) { + this.pos = pos; + this.isDefined = isDefined; + this.keyHashcode = keyHashcode; + if (isDefined) { + final long fullKeyAddress = longArray.get(pos * 2); + updateAddressesAndSizes(fullKeyAddress); + } + return this; + } + + /** + * Returns true if the key is defined at this position, and false otherwise. + */ + public boolean isDefined() { + return isDefined; + } + + /** + * Returns the address of the key defined at this position. + * This points to the first byte of the key data. + * Unspecified behavior if the key is not defined. + * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + */ + public MemoryLocation getKeyAddress() { + assert (isDefined); + return keyMemoryLocation; + } + + /** + * Returns the length of the key defined at this position. + * Unspecified behavior if the key is not defined. + */ + public int getKeyLength() { + assert (isDefined); + return keyLength; + } + + /** + * Returns the address of the value defined at this position. + * This points to the first byte of the value data. + * Unspecified behavior if the key is not defined. + * For efficiency reasons, calls to this method always returns the same MemoryLocation object. + */ + public MemoryLocation getValueAddress() { + assert (isDefined); + return valueMemoryLocation; + } + + /** + * Returns the length of the value defined at this position. + * Unspecified behavior if the key is not defined. + */ + public int getValueLength() { + assert (isDefined); + return valueLength; + } + + /** + * Store a new key and value. This method may only be called once for a given key; if you want + * to update the value associated with a key, then you can directly manipulate the bytes stored + * at the value address. + * <p> + * It is only valid to call this method immediately after calling `lookup()` using the same key. + * <p> + * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length` + * will return information on the data stored by this `putNewKey` call. + * <p> + * As an example usage, here's the proper way to store a new key: + * <p> + * <pre> + * Location loc = map.lookup(keyBaseOffset, keyBaseObject, keyLengthInBytes); + * if (!loc.isDefined()) { + * loc.putNewKey(keyBaseOffset, keyBaseObject, keyLengthInBytes, ...) + * } + * </pre> + * <p> + * Unspecified behavior if the key is not defined. + */ + public void putNewKey( + Object keyBaseObject, + long keyBaseOffset, + int keyLengthBytes, + Object valueBaseObject, + long valueBaseOffset, + int valueLengthBytes) { + assert (!isDefined) : "Can only set value once for a key"; + isDefined = true; + assert (keyLengthBytes % 8 == 0); + assert (valueLengthBytes % 8 == 0); + // Here, we'll copy the data into our data pages. Because we only store a relative offset from + // the key address instead of storing the absolute address of the value, the key and value + // must be stored in the same memory page. + // (8 byte key length) (key) (8 byte value length) (value) + final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; + assert(requiredSize <= PAGE_SIZE_BYTES); + size++; + bitset.set(pos); + + // If there's not enough space in the current page, allocate a new page: + if (currentDataPage == null || PAGE_SIZE_BYTES - pageCursor < requiredSize) { + MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES); + dataPages.add(newPage); + pageCursor = 0; + currentDataPage = newPage; + } + + // Compute all of our offsets up-front: + final Object pageBaseObject = currentDataPage.getBaseObject(); + final long pageBaseOffset = currentDataPage.getBaseOffset(); + final long keySizeOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += 8; // word used to store the key size + final long keyDataOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += keyLengthBytes; + final long valueSizeOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += 8; // word used to store the value size + final long valueDataOffsetInPage = pageBaseOffset + pageCursor; + pageCursor += valueLengthBytes; + + // Copy the key + PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes); + PlatformDependent.UNSAFE.copyMemory( + keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes); + // Copy the value + PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes); + PlatformDependent.UNSAFE.copyMemory( + valueBaseObject, valueBaseOffset, pageBaseObject, valueDataOffsetInPage, valueLengthBytes); + + final long storedKeyAddress = memoryManager.encodePageNumberAndOffset( + currentDataPage, keySizeOffsetInPage); + longArray.set(pos * 2, storedKeyAddress); + longArray.set(pos * 2 + 1, keyHashcode); + updateAddressesAndSizes(storedKeyAddress); + isDefined = true; + if (size > growthThreshold) { + growAndRehash(); + } + } + } + + /** + * Allocate new data structures for this map. When calling this outside of the constructor, + * make sure to keep references to the old data structures so that you can free them. + * + * @param capacity the new map capacity + */ + private void allocate(int capacity) { + capacity = Math.max((int) Math.min(Integer.MAX_VALUE, nextPowerOf2(capacity)), 64); + longArray = new LongArray(memoryManager.allocate(capacity * 8 * 2)); + bitset = new BitSet(memoryManager.allocate(capacity / 8).zero()); + + this.growthThreshold = (int) (capacity * loadFactor); + this.mask = capacity - 1; + } + + /** + * Free all allocated memory associated with this map, including the storage for keys and values + * as well as the hash map array itself. + * + * This method is idempotent. + */ + public void free() { + if (longArray != null) { + memoryManager.free(longArray.memoryBlock()); + longArray = null; + } + if (bitset != null) { + memoryManager.free(bitset.memoryBlock()); + bitset = null; + } + Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator(); + while (dataPagesIterator.hasNext()) { + memoryManager.freePage(dataPagesIterator.next()); + dataPagesIterator.remove(); + } + assert(dataPages.isEmpty()); + } + + /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */ + public long getTotalMemoryConsumption() { + return ( + dataPages.size() * PAGE_SIZE_BYTES + + bitset.memoryBlock().size() + + longArray.memoryBlock().size()); + } + + /** + * Returns the total amount of time spent resizing this map (in nanoseconds). + */ + public long getTimeSpentResizingNs() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return timeSpentResizingNs; + } + + + /** + * Returns the average number of probes per key lookup. + */ + public double getAverageProbesPerLookup() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return (1.0 * numProbes) / numKeyLookups; + } + + public long getNumHashCollisions() { + if (!enablePerfMetrics) { + throw new IllegalStateException(); + } + return numHashCollisions; + } + + /** + * Grows the size of the hash table and re-hash everything. + */ + private void growAndRehash() { + long resizeStartTime = -1; + if (enablePerfMetrics) { + resizeStartTime = System.nanoTime(); + } + // Store references to the old data structures to be used when we re-hash + final LongArray oldLongArray = longArray; + final BitSet oldBitSet = bitset; + final int oldCapacity = (int) oldBitSet.capacity(); + + // Allocate the new data structures + allocate(Math.min(Integer.MAX_VALUE, growthStrategy.nextCapacity(oldCapacity))); + + // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it) + for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) { + final long keyPointer = oldLongArray.get(pos * 2); + final int hashcode = (int) oldLongArray.get(pos * 2 + 1); + int newPos = hashcode & mask; + int step = 1; + boolean keepGoing = true; + + // No need to check for equality here when we insert so this has one less if branch than + // the similar code path in addWithoutResize. + while (keepGoing) { + if (!bitset.isSet(newPos)) { + bitset.set(newPos); + longArray.set(newPos * 2, keyPointer); + longArray.set(newPos * 2 + 1, hashcode); + keepGoing = false; + } else { + newPos = (newPos + step) & mask; + step++; + } + } + } + + // Deallocate the old data structures. + memoryManager.free(oldLongArray.memoryBlock()); + memoryManager.free(oldBitSet.memoryBlock()); + if (enablePerfMetrics) { + timeSpentResizingNs += System.nanoTime() - resizeStartTime; + } + } + + /** Returns the next number greater or equal num that is power of 2. */ + private static long nextPowerOf2(long num) { + final long highBit = Long.highestOneBit(num); + return (highBit == num) ? num : highBit << 1; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..7c321baffe82d5cd31a23fa968e646b44b497c79 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +/** + * Interface that defines how we can grow the size of a hash map when it is over a threshold. + */ +public interface HashMapGrowthStrategy { + + int nextCapacity(int currentCapacity); + + /** + * Double the size of the hash map every time. + */ + HashMapGrowthStrategy DOUBLING = new Doubling(); + + class Doubling implements HashMapGrowthStrategy { + @Override + public int nextCapacity(int currentCapacity) { + return currentCapacity * 2; + } + } + +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java new file mode 100644 index 0000000000000000000000000000000000000000..62c29c8cc1e4d1068f42b4df8bbe3fdc05cb3910 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +/** + * Manages memory for an executor. Individual operators / tasks allocate memory through + * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager. + */ +public class ExecutorMemoryManager { + + /** + * Allocator, exposed for enabling untracked allocations of temporary data structures. + */ + public final MemoryAllocator allocator; + + /** + * Tracks whether memory will be allocated on the JVM heap or off-heap using sun.misc.Unsafe. + */ + final boolean inHeap; + + /** + * Construct a new ExecutorMemoryManager. + * + * @param allocator the allocator that will be used + */ + public ExecutorMemoryManager(MemoryAllocator allocator) { + this.inHeap = allocator instanceof HeapMemoryAllocator; + this.allocator = allocator; + } + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). + */ + MemoryBlock allocate(long size) throws OutOfMemoryError { + return allocator.allocate(size); + } + + void free(MemoryBlock memory) { + allocator.free(memory); + } + +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java new file mode 100644 index 0000000000000000000000000000000000000000..bbe83d36cf36ba8010dcae85e0c5f90898647337 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +/** + * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. + */ +public class HeapMemoryAllocator implements MemoryAllocator { + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + long[] array = new long[(int) (size / 8)]; + return MemoryBlock.fromLongArray(array); + } + + @Override + public void free(MemoryBlock memory) { + // Do nothing + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java new file mode 100644 index 0000000000000000000000000000000000000000..5192f68c862cfc9d60b674a6fecf7ec6e7ae36bd --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +public interface MemoryAllocator { + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). + */ + MemoryBlock allocate(long size) throws OutOfMemoryError; + + void free(MemoryBlock memory); + + MemoryAllocator UNSAFE = new UnsafeMemoryAllocator(); + + MemoryAllocator HEAP = new HeapMemoryAllocator(); +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java new file mode 100644 index 0000000000000000000000000000000000000000..0beb743e5644eea312eedd235f2bde1a289ffd12 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import javax.annotation.Nullable; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size. + */ +public class MemoryBlock extends MemoryLocation { + + private final long length; + + /** + * Optional page number; used when this MemoryBlock represents a page allocated by a + * MemoryManager. This is package-private and is modified by MemoryManager. + */ + int pageNumber = -1; + + MemoryBlock(@Nullable Object obj, long offset, long length) { + super(obj, offset); + this.length = length; + } + + /** + * Returns the size of the memory block. + */ + public long size() { + return length; + } + + /** + * Clear the contents of this memory block. Returns `this` to facilitate chaining. + */ + public MemoryBlock zero() { + PlatformDependent.UNSAFE.setMemory(obj, offset, length, (byte) 0); + return this; + } + + /** + * Creates a memory block pointing to the memory used by the long array. + */ + public static MemoryBlock fromLongArray(final long[] array) { + return new MemoryBlock(array, PlatformDependent.LONG_ARRAY_OFFSET, array.length * 8); + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java new file mode 100644 index 0000000000000000000000000000000000000000..74ebc87dc978c50ddd35e81e4766a6172aad5d23 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import javax.annotation.Nullable; + +/** + * A memory location. Tracked either by a memory address (with off-heap allocation), + * or by an offset from a JVM object (in-heap allocation). + */ +public class MemoryLocation { + + @Nullable + Object obj; + + long offset; + + public MemoryLocation(@Nullable Object obj, long offset) { + this.obj = obj; + this.offset = offset; + } + + public MemoryLocation() { + this(null, 0); + } + + public void setObjAndOffset(Object newObj, long newOffset) { + this.obj = newObj; + this.offset = newOffset; + } + + public final Object getBaseObject() { + return obj; + } + + public final long getBaseOffset() { + return offset; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java new file mode 100644 index 0000000000000000000000000000000000000000..9224988e6ad69cd0561a27618e7cd92979c6d101 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import java.util.*; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Manages the memory allocated by an individual task. + * <p> + * Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit longs. + * In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, memory is + * addressed by the combination of a base Object reference and a 64-bit offset within that object. + * This is a problem when we want to store pointers to data structures inside of other structures, + * such as record pointers inside hashmaps or sorting buffers. Even if we decided to use 128 bits + * to address memory, we can't just store the address of the base object since it's not guaranteed + * to remain stable as the heap gets reorganized due to GC. + * <p> + * Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap + * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to + * store a "page number" and the lower 51 bits to store an offset within this page. These page + * numbers are used to index into a "page table" array inside of the MemoryManager in order to + * retrieve the base object. + * <p> + * This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the + * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is + * approximately 35 terabytes of memory. + */ +public final class TaskMemoryManager { + + private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); + + /** + * The number of entries in the page table. + */ + private static final int PAGE_TABLE_SIZE = 1 << 13; + + /** Bit mask for the lower 51 bits of a long. */ + private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; + + /** Bit mask for the upper 13 bits of a long */ + private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; + + /** + * Similar to an operating system's page table, this array maps page numbers into base object + * pointers, allowing us to translate between the hashtable's internal 64-bit address + * representation and the baseObject+offset representation which we use to support both in- and + * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`. + * When using an in-heap allocator, the entries in this map will point to pages' base objects. + * Entries are added to this map as new data pages are allocated. + */ + private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE]; + + /** + * Bitmap for tracking free pages. + */ + private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE); + + /** + * Tracks memory allocated with {@link TaskMemoryManager#allocate(long)}, used to detect / clean + * up leaked memory. + */ + private final HashSet<MemoryBlock> allocatedNonPageMemory = new HashSet<MemoryBlock>(); + + private final ExecutorMemoryManager executorMemoryManager; + + /** + * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods + * without doing any masking or lookups. Since this branching should be well-predicted by the JIT, + * this extra layer of indirection / abstraction hopefully shouldn't be too expensive. + */ + private final boolean inHeap; + + /** + * Construct a new MemoryManager. + */ + public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { + this.inHeap = executorMemoryManager.inHeap; + this.executorMemoryManager = executorMemoryManager; + } + + /** + * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is + * intended for allocating large blocks of memory that will be shared between operators. + */ + public MemoryBlock allocatePage(long size) { + if (logger.isTraceEnabled()) { + logger.trace("Allocating {} byte page", size); + } + if (size >= (1L << 51)) { + throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes"); + } + + final int pageNumber; + synchronized (this) { + pageNumber = allocatedPages.nextClearBit(0); + if (pageNumber >= PAGE_TABLE_SIZE) { + throw new IllegalStateException( + "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); + } + allocatedPages.set(pageNumber); + } + final MemoryBlock page = executorMemoryManager.allocate(size); + page.pageNumber = pageNumber; + pageTable[pageNumber] = page; + if (logger.isDebugEnabled()) { + logger.debug("Allocate page number {} ({} bytes)", pageNumber, size); + } + return page; + } + + /** + * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. + */ + public void freePage(MemoryBlock page) { + if (logger.isTraceEnabled()) { + logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size()); + } + assert (page.pageNumber != -1) : + "Called freePage() on memory that wasn't allocated with allocatePage()"; + executorMemoryManager.free(page); + synchronized (this) { + allocatedPages.clear(page.pageNumber); + } + pageTable[page.pageNumber] = null; + if (logger.isDebugEnabled()) { + logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + } + } + + /** + * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed + * to be zeroed out (call `zero()` on the result if this is necessary). This method is intended + * to be used for allocating operators' internal data structures. For data pages that you want to + * exchange between operators, consider using {@link TaskMemoryManager#allocatePage(long)}, since + * that will enable intra-memory pointers (see + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} and this class's + * top-level Javadoc for more details). + */ + public MemoryBlock allocate(long size) throws OutOfMemoryError { + final MemoryBlock memory = executorMemoryManager.allocate(size); + allocatedNonPageMemory.add(memory); + return memory; + } + + /** + * Free memory allocated by {@link TaskMemoryManager#allocate(long)}. + */ + public void free(MemoryBlock memory) { + assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()"; + executorMemoryManager.free(memory); + final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); + assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; + } + + /** + * Given a memory page and offset within that page, encode this address into a 64-bit long. + * This address will remain valid as long as the corresponding page has not been freed. + */ + public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { + if (inHeap) { + assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + } else { + return offsetInPage; + } + } + + /** + * Get the page associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public Object getPage(long pagePlusOffsetAddress) { + if (inHeap) { + final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final Object page = pageTable[pageNumber].getBaseObject(); + assert (page != null); + return page; + } else { + return null; + } + } + + /** + * Get the offset associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public long getOffsetInPage(long pagePlusOffsetAddress) { + if (inHeap) { + return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + } else { + return pagePlusOffsetAddress; + } + } + + /** + * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return + * value can be used to detect memory leaks. + */ + public long cleanUpAllAllocatedMemory() { + long freedBytes = 0; + for (MemoryBlock page : pageTable) { + if (page != null) { + freedBytes += page.size(); + freePage(page); + } + } + final Iterator<MemoryBlock> iter = allocatedNonPageMemory.iterator(); + while (iter.hasNext()) { + final MemoryBlock memory = iter.next(); + freedBytes += memory.size(); + // We don't call free() here because that calls Set.remove, which would lead to a + // ConcurrentModificationException here. + executorMemoryManager.free(memory); + iter.remove(); + } + return freedBytes; + } +} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java new file mode 100644 index 0000000000000000000000000000000000000000..15898771fef25350071721444d0d06e5befa83d1 --- /dev/null +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.apache.spark.unsafe.PlatformDependent; + +/** + * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory. + */ +public class UnsafeMemoryAllocator implements MemoryAllocator { + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + long address = PlatformDependent.UNSAFE.allocateMemory(size); + return new MemoryBlock(null, address, size); + } + + @Override + public void free(MemoryBlock memory) { + assert (memory.obj == null) : + "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?"; + PlatformDependent.UNSAFE.freeMemory(memory.offset); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java new file mode 100644 index 0000000000000000000000000000000000000000..5974cf91ff993cace8949c6611a39e4f3649b27e --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.array; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.unsafe.memory.MemoryBlock; + +public class LongArraySuite { + + @Test + public void basicTest() { + long[] bytes = new long[2]; + LongArray arr = new LongArray(MemoryBlock.fromLongArray(bytes)); + arr.set(0, 1L); + arr.set(1, 2L); + arr.set(1, 3L); + Assert.assertEquals(2, arr.size()); + Assert.assertEquals(1L, arr.get(0)); + Assert.assertEquals(3L, arr.get(1)); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..4bf132fd4053e01320ecf006b6c4b86535de6fd9 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/BitSetSuite.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.bitset; + +import junit.framework.Assert; +import org.apache.spark.unsafe.bitset.BitSet; +import org.junit.Test; + +import org.apache.spark.unsafe.memory.MemoryBlock; + +public class BitSetSuite { + + private static BitSet createBitSet(int capacity) { + assert capacity % 64 == 0; + return new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]).zero()); + } + + @Test + public void basicOps() { + BitSet bs = createBitSet(64); + Assert.assertEquals(64, bs.capacity()); + + // Make sure the bit set starts empty. + for (int i = 0; i < bs.capacity(); i++) { + Assert.assertFalse(bs.isSet(i)); + } + + // Set every bit and check it. + for (int i = 0; i < bs.capacity(); i++) { + bs.set(i); + Assert.assertTrue(bs.isSet(i)); + } + + // Unset every bit and check it. + for (int i = 0; i < bs.capacity(); i++) { + Assert.assertTrue(bs.isSet(i)); + bs.unset(i); + Assert.assertFalse(bs.isSet(i)); + } + } + + @Test + public void traversal() { + BitSet bs = createBitSet(256); + + Assert.assertEquals(-1, bs.nextSetBit(0)); + Assert.assertEquals(-1, bs.nextSetBit(10)); + Assert.assertEquals(-1, bs.nextSetBit(64)); + + bs.set(10); + Assert.assertEquals(10, bs.nextSetBit(0)); + Assert.assertEquals(10, bs.nextSetBit(1)); + Assert.assertEquals(10, bs.nextSetBit(10)); + Assert.assertEquals(-1, bs.nextSetBit(11)); + + bs.set(11); + Assert.assertEquals(10, bs.nextSetBit(10)); + Assert.assertEquals(11, bs.nextSetBit(11)); + + // Skip a whole word and find it + bs.set(190); + Assert.assertEquals(190, bs.nextSetBit(12)); + + Assert.assertEquals(-1, bs.nextSetBit(191)); + Assert.assertEquals(-1, bs.nextSetBit(256)); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java new file mode 100644 index 0000000000000000000000000000000000000000..3b9175835229cbfc5b0d9e8146af8f9f5f4c9453 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/Murmur3_x86_32Suite.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.hash; + +import java.util.HashSet; +import java.util.Random; +import java.util.Set; + +import junit.framework.Assert; +import org.apache.spark.unsafe.PlatformDependent; +import org.junit.Test; + +/** + * Test file based on Guava's Murmur3Hash32Test. + */ +public class Murmur3_x86_32Suite { + + private static final Murmur3_x86_32 hasher = new Murmur3_x86_32(0); + + @Test + public void testKnownIntegerInputs() { + Assert.assertEquals(593689054, hasher.hashInt(0)); + Assert.assertEquals(-189366624, hasher.hashInt(-42)); + Assert.assertEquals(-1134849565, hasher.hashInt(42)); + Assert.assertEquals(-1718298732, hasher.hashInt(Integer.MIN_VALUE)); + Assert.assertEquals(-1653689534, hasher.hashInt(Integer.MAX_VALUE)); + } + + @Test + public void testKnownLongInputs() { + Assert.assertEquals(1669671676, hasher.hashLong(0L)); + Assert.assertEquals(-846261623, hasher.hashLong(-42L)); + Assert.assertEquals(1871679806, hasher.hashLong(42L)); + Assert.assertEquals(1366273829, hasher.hashLong(Long.MIN_VALUE)); + Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE)); + } + + @Test + public void randomizedStressTest() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set<Integer> hashcodes = new HashSet<Integer>(); + for (int i = 0; i < size; i++) { + int vint = rand.nextInt(); + long lint = rand.nextLong(); + Assert.assertEquals(hasher.hashInt(vint), hasher.hashInt(vint)); + Assert.assertEquals(hasher.hashLong(lint), hasher.hashLong(lint)); + + hashcodes.add(hasher.hashLong(lint)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestBytes() { + int size = 65536; + Random rand = new Random(); + + // A set used to track collision rate. + Set<Integer> hashcodes = new HashSet<Integer>(); + for (int i = 0; i < size; i++) { + int byteArrSize = rand.nextInt(100) * 8; + byte[] bytes = new byte[byteArrSize]; + rand.nextBytes(bytes); + + Assert.assertEquals( + hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } + + @Test + public void randomizedStressTestPaddedStrings() { + int size = 64000; + // A set used to track collision rate. + Set<Integer> hashcodes = new HashSet<Integer>(); + for (int i = 0; i < size; i++) { + int byteArrSize = 8; + byte[] strBytes = ("" + i).getBytes(); + byte[] paddedBytes = new byte[byteArrSize]; + System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length); + + Assert.assertEquals( + hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize), + hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + + hashcodes.add(hasher.hashUnsafeWords( + paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize)); + } + + // A very loose bound. + Assert.assertTrue(hashcodes.size() > size * 0.95); + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..9038cf567f1e2d5e3b3f0c221acf6bcfdb9024d2 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import java.lang.Exception; +import java.nio.ByteBuffer; +import java.util.*; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.PlatformDependent; +import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET; +import org.apache.spark.unsafe.memory.ExecutorMemoryManager; +import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.unsafe.memory.MemoryLocation; +import org.apache.spark.unsafe.memory.TaskMemoryManager; + +public abstract class AbstractBytesToBytesMapSuite { + + private final Random rand = new Random(42); + + private TaskMemoryManager memoryManager; + + @Before + public void setup() { + memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); + } + + @After + public void tearDown() { + if (memoryManager != null) { + memoryManager.cleanUpAllAllocatedMemory(); + memoryManager = null; + } + } + + protected abstract MemoryAllocator getMemoryAllocator(); + + private static byte[] getByteArray(MemoryLocation loc, int size) { + final byte[] arr = new byte[size]; + PlatformDependent.UNSAFE.copyMemory( + loc.getBaseObject(), + loc.getBaseOffset(), + arr, + BYTE_ARRAY_OFFSET, + size + ); + return arr; + } + + private byte[] getRandomByteArray(int numWords) { + Assert.assertTrue(numWords > 0); + final int lengthInBytes = numWords * 8; + final byte[] bytes = new byte[lengthInBytes]; + rand.nextBytes(bytes); + return bytes; + } + + /** + * Fast equality checking for byte arrays, since these comparisons are a bottleneck + * in our stress tests. + */ + private static boolean arrayEquals( + byte[] expected, + MemoryLocation actualAddr, + long actualLengthBytes) { + return (actualLengthBytes == expected.length) && ByteArrayMethods.wordAlignedArrayEquals( + expected, + BYTE_ARRAY_OFFSET, + actualAddr.getBaseObject(), + actualAddr.getBaseOffset(), + expected.length + ); + } + + @Test + public void emptyMap() { + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + try { + Assert.assertEquals(0, map.size()); + final int keyLengthInWords = 10; + final int keyLengthInBytes = keyLengthInWords * 8; + final byte[] key = getRandomByteArray(keyLengthInWords); + Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined()); + } finally { + map.free(); + } + } + + @Test + public void setAndRetrieveAKey() { + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + final int recordLengthWords = 10; + final int recordLengthBytes = recordLengthWords * 8; + final byte[] keyData = getRandomByteArray(recordLengthWords); + final byte[] valueData = getRandomByteArray(recordLengthWords); + try { + final BytesToBytesMap.Location loc = + map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + keyData, + BYTE_ARRAY_OFFSET, + recordLengthBytes, + valueData, + BYTE_ARRAY_OFFSET, + recordLengthBytes + ); + // After storing the key and value, the other location methods should return results that + // reflect the result of this store without us having to call lookup() again on the same key. + Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); + Assert.assertEquals(recordLengthBytes, loc.getValueLength()); + Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + + // After calling lookup() the location should still point to the correct data. + Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined()); + Assert.assertEquals(recordLengthBytes, loc.getKeyLength()); + Assert.assertEquals(recordLengthBytes, loc.getValueLength()); + Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes)); + Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes)); + + try { + loc.putNewKey( + keyData, + BYTE_ARRAY_OFFSET, + recordLengthBytes, + valueData, + BYTE_ARRAY_OFFSET, + recordLengthBytes + ); + Assert.fail("Should not be able to set a new value for a key"); + } catch (AssertionError e) { + // Expected exception; do nothing. + } + } finally { + map.free(); + } + } + + @Test + public void iteratorTest() throws Exception { + final int size = 128; + BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2); + try { + for (long i = 0; i < size; i++) { + final long[] value = new long[] { i }; + final BytesToBytesMap.Location loc = + map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8, + value, + PlatformDependent.LONG_ARRAY_OFFSET, + 8 + ); + } + final java.util.BitSet valuesSeen = new java.util.BitSet(size); + final Iterator<BytesToBytesMap.Location> iter = map.iterator(); + while (iter.hasNext()) { + final BytesToBytesMap.Location loc = iter.next(); + Assert.assertTrue(loc.isDefined()); + final MemoryLocation keyAddress = loc.getKeyAddress(); + final MemoryLocation valueAddress = loc.getValueAddress(); + final long key = PlatformDependent.UNSAFE.getLong( + keyAddress.getBaseObject(), keyAddress.getBaseOffset()); + final long value = PlatformDependent.UNSAFE.getLong( + valueAddress.getBaseObject(), valueAddress.getBaseOffset()); + Assert.assertEquals(key, value); + valuesSeen.set((int) value); + } + Assert.assertEquals(size, valuesSeen.cardinality()); + } finally { + map.free(); + } + } + + @Test + public void randomizedStressTest() { + final int size = 65536; + // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays + // into ByteBuffers in order to use them as keys here. + final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>(); + final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size); + + try { + // Fill the map to 90% full so that we can trigger probing + for (int i = 0; i < size * 0.9; i++) { + final byte[] key = getRandomByteArray(rand.nextInt(256) + 1); + final byte[] value = getRandomByteArray(rand.nextInt(512) + 1); + if (!expected.containsKey(ByteBuffer.wrap(key))) { + expected.put(ByteBuffer.wrap(key), value); + final BytesToBytesMap.Location loc = map.lookup( + key, + BYTE_ARRAY_OFFSET, + key.length + ); + Assert.assertFalse(loc.isDefined()); + loc.putNewKey( + key, + BYTE_ARRAY_OFFSET, + key.length, + value, + BYTE_ARRAY_OFFSET, + value.length + ); + // After calling putNewKey, the following should be true, even before calling + // lookup(): + Assert.assertTrue(loc.isDefined()); + Assert.assertEquals(key.length, loc.getKeyLength()); + Assert.assertEquals(value.length, loc.getValueLength()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length)); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length)); + } + } + + for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) { + final byte[] key = entry.getKey().array(); + final byte[] value = entry.getValue(); + final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length); + Assert.assertTrue(loc.isDefined()); + Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength())); + Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength())); + } + } finally { + map.free(); + } + } +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..5a10de49f54fec4bb18367f04ef798708e76b3a9 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import org.apache.spark.unsafe.memory.MemoryAllocator; + +public class BytesToBytesMapOffHeapSuite extends AbstractBytesToBytesMapSuite { + + @Override + protected MemoryAllocator getMemoryAllocator() { + return MemoryAllocator.UNSAFE; + } + +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..12cc9b25d93b308eafb9bc891a38ff45c27e1ade --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.map; + +import org.apache.spark.unsafe.memory.MemoryAllocator; + +public class BytesToBytesMapOnHeapSuite extends AbstractBytesToBytesMapSuite { + + @Override + protected MemoryAllocator getMemoryAllocator() { + return MemoryAllocator.HEAP; + } + +} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..932882f1ca24817147f060b9e1ae14ae31c7d408 --- /dev/null +++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.unsafe.memory; + +import org.junit.Assert; +import org.junit.Test; + +public class TaskMemoryManagerSuite { + + @Test + public void leakedNonPageMemoryIsDetected() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + manager.allocate(1024); // leak memory + Assert.assertEquals(1024, manager.cleanUpAllAllocatedMemory()); + } + + @Test + public void leakedPageMemoryIsDetected() { + final TaskMemoryManager manager = + new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + manager.allocatePage(4096); // leak memory + Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); + } + +}