diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index f645eb5f7bb01b5ae446a02b6fb215a55453a106..063940cb9e2c3b58f1441c5a3c8c54dd56bc4d83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -493,6 +493,50 @@ object DateTimeUtils {
     }
   }
 
+  /**
+   * Split date (expressed in days since 1.1.1970) into four fields:
+   * year, month (Jan is Month 1), dayInMonth, daysToMonthEnd (0 if it's last day of month).
+   */
+  def splitDate(date: Int): (Int, Int, Int, Int) = {
+    var (year, dayInYear) = getYearAndDayInYear(date)
+    val isLeap = isLeapYear(year)
+    if (isLeap && dayInYear == 60) {
+      (year, 2, 29, 0)
+    } else {
+      if (isLeap && dayInYear > 60) dayInYear -= 1
+
+      if (dayInYear <= 181) {
+        if (dayInYear <= 31) {
+          (year, 1, dayInYear, 31 - dayInYear)
+        } else if (dayInYear <= 59) {
+          (year, 2, dayInYear - 31, if (isLeap) 60 - dayInYear else 59 - dayInYear)
+        } else if (dayInYear <= 90) {
+          (year, 3, dayInYear - 59, 90 - dayInYear)
+        } else if (dayInYear <= 120) {
+          (year, 4, dayInYear - 90, 120 - dayInYear)
+        } else if (dayInYear <= 151) {
+          (year, 5, dayInYear - 120, 151 - dayInYear)
+        } else {
+          (year, 6, dayInYear - 151, 181 - dayInYear)
+        }
+      } else {
+        if (dayInYear <= 212) {
+          (year, 7, dayInYear - 181, 212 - dayInYear)
+        } else if (dayInYear <= 243) {
+          (year, 8, dayInYear - 212, 243 - dayInYear)
+        } else if (dayInYear <= 273) {
+          (year, 9, dayInYear - 243, 273 - dayInYear)
+        } else if (dayInYear <= 304) {
+          (year, 10, dayInYear - 273, 304 - dayInYear)
+        } else if (dayInYear <= 334) {
+          (year, 11, dayInYear - 304, 334 - dayInYear)
+        } else {
+          (year, 12, dayInYear - 334, 365 - dayInYear)
+        }
+      }
+    }
+  }
+
   /**
    * Returns the month value for the given date. The date is expressed in days
    * since 1.1.1970. January is month 1.
@@ -613,15 +657,16 @@ object DateTimeUtils {
    * Returns a date value, expressed in days since 1.1.1970.
    */
   def dateAddMonths(days: Int, months: Int): Int = {
-    val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months
+    val (year, monthInYear, dayOfMonth, daysToMonthEnd) = splitDate(days)
+    val absoluteMonth = (year - YearZero) * 12 + monthInYear - 1 + months
     val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0
     val currentMonthInYear = nonNegativeMonth % 12
     val currentYear = nonNegativeMonth / 12
+
     val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0
     val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay
 
-    val dayOfMonth = getDayOfMonth(days)
-    val currentDayInMonth = if (getDayOfMonth(days + 1) == 1 || dayOfMonth >= lastDayOfMonth) {
+    val currentDayInMonth = if (daysToMonthEnd == 0 || dayOfMonth >= lastDayOfMonth) {
       // last day of the month
       lastDayOfMonth
     } else {
@@ -640,46 +685,6 @@ object DateTimeUtils {
     daysToMillis(newDays) * 1000L + start - daysToMillis(days) * 1000L + microseconds
   }
 
-  /**
-   * Returns the last dayInMonth in the month it belongs to. The date is expressed
-   * in days since 1.1.1970. the return value starts from 1.
-   */
-  private def getLastDayInMonthOfMonth(date: Int): Int = {
-    var (year, dayInYear) = getYearAndDayInYear(date)
-    if (isLeapYear(year)) {
-      if (dayInYear > 31 && dayInYear <= 60) {
-        return 29
-      } else if (dayInYear > 60) {
-        dayInYear = dayInYear - 1
-      }
-    }
-    if (dayInYear <= 31) {
-      31
-    } else if (dayInYear <= 59) {
-      28
-    } else if (dayInYear <= 90) {
-      31
-    } else if (dayInYear <= 120) {
-      30
-    } else if (dayInYear <= 151) {
-      31
-    } else if (dayInYear <= 181) {
-      30
-    } else if (dayInYear <= 212) {
-      31
-    } else if (dayInYear <= 243) {
-      31
-    } else if (dayInYear <= 273) {
-      30
-    } else if (dayInYear <= 304) {
-      31
-    } else if (dayInYear <= 334) {
-      30
-    } else {
-      31
-    }
-  }
-
   /**
    * Returns number of months between time1 and time2. time1 and time2 are expressed in
    * microseconds since 1.1.1970.
@@ -695,14 +700,13 @@ object DateTimeUtils {
     val millis2 = time2 / 1000L
     val date1 = millisToDays(millis1)
     val date2 = millisToDays(millis2)
-    // TODO(davies): get year, month, dayOfMonth from single function
-    val dayInMonth1 = getDayOfMonth(date1)
-    val dayInMonth2 = getDayOfMonth(date2)
-    val months1 = getYear(date1) * 12 + getMonth(date1)
-    val months2 = getYear(date2) * 12 + getMonth(date2)
-
-    if (dayInMonth1 == dayInMonth2 || (dayInMonth1 == getLastDayInMonthOfMonth(date1)
-      && dayInMonth2 == getLastDayInMonthOfMonth(date2))) {
+    val (year1, monthInYear1, dayInMonth1, daysToMonthEnd1) = splitDate(date1)
+    val (year2, monthInYear2, dayInMonth2, daysToMonthEnd2) = splitDate(date2)
+
+    val months1 = year1 * 12 + monthInYear1
+    val months2 = year2 * 12 + monthInYear2
+
+    if (dayInMonth1 == dayInMonth2 || ((daysToMonthEnd1 == 0) && (daysToMonthEnd2 == 0))) {
       return (months1 - months2).toDouble
     }
     // milliseconds is enough for 8 digits precision on the right side
@@ -745,40 +749,8 @@ object DateTimeUtils {
    * since 1.1.1970.
    */
   def getLastDayOfMonth(date: Int): Int = {
-    var (year, dayInYear) = getYearAndDayInYear(date)
-    if (isLeapYear(year)) {
-      if (dayInYear > 31 && dayInYear <= 60) {
-        return date + (60 - dayInYear)
-      } else if (dayInYear > 60) {
-        dayInYear = dayInYear - 1
-      }
-    }
-    val lastDayOfMonthInYear = if (dayInYear <= 31) {
-      31
-    } else if (dayInYear <= 59) {
-      59
-    } else if (dayInYear <= 90) {
-      90
-    } else if (dayInYear <= 120) {
-      120
-    } else if (dayInYear <= 151) {
-      151
-    } else if (dayInYear <= 181) {
-      181
-    } else if (dayInYear <= 212) {
-      212
-    } else if (dayInYear <= 243) {
-      243
-    } else if (dayInYear <= 273) {
-      273
-    } else if (dayInYear <= 304) {
-      304
-    } else if (dayInYear <= 334) {
-      334
-    } else {
-      365
-    }
-    date + (lastDayOfMonthInYear - dayInYear)
+    val (_, _, _, daysToMonthEnd) = splitDate(date)
+    date + daysToMonthEnd
   }
 
   private val TRUNC_TO_YEAR = 1