diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index a78440df4f3e18e6da3516db912a0072fad00f12..57006bfaf9b695b9bdba30910a60a005a1523c72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.catalog
 
 import java.net.URI
 import java.util.Locale
+import java.util.concurrent.Callable
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
@@ -125,14 +126,36 @@ class SessionCatalog(
     if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT)
   }
 
-  /**
-   * A cache of qualified table names to table relation plans.
-   */
-  val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = {
+  private val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = {
     val cacheSize = conf.tableRelationCacheSize
     CacheBuilder.newBuilder().maximumSize(cacheSize).build[QualifiedTableName, LogicalPlan]()
   }
 
+  /** This method provides a way to get a cached plan. */
+  def getCachedPlan(t: QualifiedTableName, c: Callable[LogicalPlan]): LogicalPlan = {
+    tableRelationCache.get(t, c)
+  }
+
+  /** This method provides a way to get a cached plan if the key exists. */
+  def getCachedTable(key: QualifiedTableName): LogicalPlan = {
+    tableRelationCache.getIfPresent(key)
+  }
+
+  /** This method provides a way to cache a plan. */
+  def cacheTable(t: QualifiedTableName, l: LogicalPlan): Unit = {
+    tableRelationCache.put(t, l)
+  }
+
+  /** This method provides a way to invalidate a cached plan. */
+  def invalidateCachedTable(key: QualifiedTableName): Unit = {
+    tableRelationCache.invalidate(key)
+  }
+
+  /** This method provides a way to invalidate all the cached plans. */
+  def invalidateAllCachedTables(): Unit = {
+    tableRelationCache.invalidateAll()
+  }
+
   /**
    * This method is used to make the given path qualified before we
    * store this path in the underlying external catalog. So, when a path
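For reviewers, a minimal usage sketch of the new SessionCatalog accessors. This is not part of the patch: the method name `demo`, the database/table names, and `somePlan` are illustrative stand-ins, and the imports assume the usual catalyst packages.

import org.apache.spark.sql.catalyst.catalog.{QualifiedTableName, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Illustrative caller: `catalog` is an existing SessionCatalog, `somePlan` a resolved plan.
def demo(catalog: SessionCatalog, somePlan: LogicalPlan): Unit = {
  val key = QualifiedTableName("db", "tbl")
  catalog.cacheTable(key, somePlan)        // put a plan into the relation cache
  val cached = catalog.getCachedTable(key) // Guava getIfPresent semantics: null on a miss
  assert(cached eq somePlan)
  catalog.invalidateCachedTable(key)       // drop a single entry
  catalog.invalidateAllCachedTables()      // drop every entry, e.g. between tests
}
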
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 21d75a404911b56ab246df1db51485a9c93c5f2b..e05a8d5f02bd8b187a29fecea2fee7a300dcae6b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -215,9 +215,9 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
   private def readDataSourceTable(r: CatalogRelation): LogicalPlan = {
     val table = r.tableMeta
     val qualifiedTableName = QualifiedTableName(table.database, table.identifier.table)
-    val cache = sparkSession.sessionState.catalog.tableRelationCache
+    val catalogProxy = sparkSession.sessionState.catalog
 
-    val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() {
+    val plan = catalogProxy.getCachedPlan(qualifiedTableName, new Callable[LogicalPlan]() {
       override def call(): LogicalPlan = {
         val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_))
         val dataSource =
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 6b98066cb76c826d6f27b71b53fa64f17c2841c5..9b3cbb63a21b0347a1db4c6456fb213ebc9784f0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.types._
 private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging {
   // these are def_s and not val/lazy val since the latter would introduce circular references
   private def sessionState = sparkSession.sessionState
-  private def tableRelationCache = sparkSession.sessionState.catalog.tableRelationCache
+  private def catalogProxy = sparkSession.sessionState.catalog
   import HiveMetastoreCatalog._
 
   /** These locks guard against multiple attempts to instantiate a table, which wastes memory. */
@@ -61,7 +61,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
     val key = QualifiedTableName(
       table.database.getOrElse(sessionState.catalog.getCurrentDatabase).toLowerCase,
       table.table.toLowerCase)
-    tableRelationCache.getIfPresent(key)
+    catalogProxy.getCachedTable(key)
   }
 
   private def getCached(
@@ -71,7 +71,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
       expectedFileFormat: Class[_ <: FileFormat],
       partitionSchema: Option[StructType]): Option[LogicalRelation] = {
 
-    tableRelationCache.getIfPresent(tableIdentifier) match {
+    catalogProxy.getCachedTable(tableIdentifier) match {
       case null => None // Cache miss
       case logical @ LogicalRelation(relation: HadoopFsRelation, _, _) =>
         val cachedRelationFileFormatClass = relation.fileFormat.getClass
@@ -92,21 +92,21 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
               Some(logical)
             } else {
               // If the cached relation is not updated, we invalidate it right away.
-              tableRelationCache.invalidate(tableIdentifier)
+              catalogProxy.invalidateCachedTable(tableIdentifier)
               None
             }
           case _ =>
             logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " +
               s"However, we are getting a ${relation.fileFormat} from the metastore cache. " +
               "This cached entry will be invalidated.")
-            tableRelationCache.invalidate(tableIdentifier)
+            catalogProxy.invalidateCachedTable(tableIdentifier)
            None
        }
      case other =>
        logWarning(s"Table $tableIdentifier should be stored as $expectedFileFormat. " +
          s"However, we are getting a $other from the metastore cache. " +
          "This cached entry will be invalidated.")
-        tableRelationCache.invalidate(tableIdentifier)
+        catalogProxy.invalidateCachedTable(tableIdentifier)
        None
    }
  }
@@ -176,7 +176,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
             fileFormat = fileFormat,
             options = options)(sparkSession = sparkSession)
           val created = LogicalRelation(fsRelation, updatedTable)
-          tableRelationCache.put(tableIdentifier, created)
+          catalogProxy.cacheTable(tableIdentifier, created)
           created
         }
 
@@ -205,7 +205,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
               className = fileType).resolveRelation(),
             table = updatedTable)
 
-          tableRelationCache.put(tableIdentifier, created)
+          catalogProxy.cacheTable(tableIdentifier, created)
           created
         }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
index b3a06045b5fd4d60adb759c8f0c61aec0290638c..d271acc63de087aa16e68410c783d1ebad2cfd79 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
@@ -46,7 +46,7 @@ class HiveSchemaInferenceSuite
 
   override def afterEach(): Unit = {
     super.afterEach()
-    spark.sessionState.catalog.tableRelationCache.invalidateAll()
+    spark.sessionState.catalog.invalidateAllCachedTables()
     FileStatusCache.resetForTesting()
   }