diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 853dd0ff3f60103159dd372131b6a65473e17bab..26bd3fb7eb27b0c0a0947ac121a255367888f45e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -29,6 +29,7 @@ import org.scalatest.BeforeAndAfterAll
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
@@ -196,7 +197,12 @@ private[sql] trait SQLTestUtils
       fail("Failed to create temporary database", cause)
     }
 
-    try f(dbName) finally spark.sql(s"DROP DATABASE $dbName CASCADE")
+    try f(dbName) finally {
+      if (spark.catalog.currentDatabase == dbName) {
+        spark.sql(s"USE ${DEFAULT_DATABASE}")
+      }
+      spark.sql(s"DROP DATABASE $dbName CASCADE")
+    }
   }
 
   /**