diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f645eb5f7bb01b5ae446a02b6fb215a55453a106..063940cb9e2c3b58f1441c5a3c8c54dd56bc4d83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -493,6 +493,50 @@ object DateTimeUtils { } } + /** + * Split date (expressed in days since 1.1.1970) into four fields: + * year, month (Jan is Month 1), dayInMonth, daysToMonthEnd (0 if it's last day of month). + */ + def splitDate(date: Int): (Int, Int, Int, Int) = { + var (year, dayInYear) = getYearAndDayInYear(date) + val isLeap = isLeapYear(year) + if (isLeap && dayInYear == 60) { + (year, 2, 29, 0) + } else { + if (isLeap && dayInYear > 60) dayInYear -= 1 + + if (dayInYear <= 181) { + if (dayInYear <= 31) { + (year, 1, dayInYear, 31 - dayInYear) + } else if (dayInYear <= 59) { + (year, 2, dayInYear - 31, if (isLeap) 60 - dayInYear else 59 - dayInYear) + } else if (dayInYear <= 90) { + (year, 3, dayInYear - 59, 90 - dayInYear) + } else if (dayInYear <= 120) { + (year, 4, dayInYear - 90, 120 - dayInYear) + } else if (dayInYear <= 151) { + (year, 5, dayInYear - 120, 151 - dayInYear) + } else { + (year, 6, dayInYear - 151, 181 - dayInYear) + } + } else { + if (dayInYear <= 212) { + (year, 7, dayInYear - 181, 212 - dayInYear) + } else if (dayInYear <= 243) { + (year, 8, dayInYear - 212, 243 - dayInYear) + } else if (dayInYear <= 273) { + (year, 9, dayInYear - 243, 273 - dayInYear) + } else if (dayInYear <= 304) { + (year, 10, dayInYear - 273, 304 - dayInYear) + } else if (dayInYear <= 334) { + (year, 11, dayInYear - 304, 334 - dayInYear) + } else { + (year, 12, dayInYear - 334, 365 - dayInYear) + } + } + } + } + /** * Returns the month value for the given date. The date is expressed in days * since 1.1.1970. January is month 1. @@ -613,15 +657,16 @@ object DateTimeUtils { * Returns a date value, expressed in days since 1.1.1970. */ def dateAddMonths(days: Int, months: Int): Int = { - val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months + val (year, monthInYear, dayOfMonth, daysToMonthEnd) = splitDate(days) + val absoluteMonth = (year - YearZero) * 12 + monthInYear - 1 + months val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0 val currentMonthInYear = nonNegativeMonth % 12 val currentYear = nonNegativeMonth / 12 + val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0 val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay - val dayOfMonth = getDayOfMonth(days) - val currentDayInMonth = if (getDayOfMonth(days + 1) == 1 || dayOfMonth >= lastDayOfMonth) { + val currentDayInMonth = if (daysToMonthEnd == 0 || dayOfMonth >= lastDayOfMonth) { // last day of the month lastDayOfMonth } else { @@ -640,46 +685,6 @@ object DateTimeUtils { daysToMillis(newDays) * 1000L + start - daysToMillis(days) * 1000L + microseconds } - /** - * Returns the last dayInMonth in the month it belongs to. The date is expressed - * in days since 1.1.1970. the return value starts from 1. - */ - private def getLastDayInMonthOfMonth(date: Int): Int = { - var (year, dayInYear) = getYearAndDayInYear(date) - if (isLeapYear(year)) { - if (dayInYear > 31 && dayInYear <= 60) { - return 29 - } else if (dayInYear > 60) { - dayInYear = dayInYear - 1 - } - } - if (dayInYear <= 31) { - 31 - } else if (dayInYear <= 59) { - 28 - } else if (dayInYear <= 90) { - 31 - } else if (dayInYear <= 120) { - 30 - } else if (dayInYear <= 151) { - 31 - } else if (dayInYear <= 181) { - 30 - } else if (dayInYear <= 212) { - 31 - } else if (dayInYear <= 243) { - 31 - } else if (dayInYear <= 273) { - 30 - } else if (dayInYear <= 304) { - 31 - } else if (dayInYear <= 334) { - 30 - } else { - 31 - } - } - /** * Returns number of months between time1 and time2. time1 and time2 are expressed in * microseconds since 1.1.1970. @@ -695,14 +700,13 @@ object DateTimeUtils { val millis2 = time2 / 1000L val date1 = millisToDays(millis1) val date2 = millisToDays(millis2) - // TODO(davies): get year, month, dayOfMonth from single function - val dayInMonth1 = getDayOfMonth(date1) - val dayInMonth2 = getDayOfMonth(date2) - val months1 = getYear(date1) * 12 + getMonth(date1) - val months2 = getYear(date2) * 12 + getMonth(date2) - - if (dayInMonth1 == dayInMonth2 || (dayInMonth1 == getLastDayInMonthOfMonth(date1) - && dayInMonth2 == getLastDayInMonthOfMonth(date2))) { + val (year1, monthInYear1, dayInMonth1, daysToMonthEnd1) = splitDate(date1) + val (year2, monthInYear2, dayInMonth2, daysToMonthEnd2) = splitDate(date2) + + val months1 = year1 * 12 + monthInYear1 + val months2 = year2 * 12 + monthInYear2 + + if (dayInMonth1 == dayInMonth2 || ((daysToMonthEnd1 == 0) && (daysToMonthEnd2 == 0))) { return (months1 - months2).toDouble } // milliseconds is enough for 8 digits precision on the right side @@ -745,40 +749,8 @@ object DateTimeUtils { * since 1.1.1970. */ def getLastDayOfMonth(date: Int): Int = { - var (year, dayInYear) = getYearAndDayInYear(date) - if (isLeapYear(year)) { - if (dayInYear > 31 && dayInYear <= 60) { - return date + (60 - dayInYear) - } else if (dayInYear > 60) { - dayInYear = dayInYear - 1 - } - } - val lastDayOfMonthInYear = if (dayInYear <= 31) { - 31 - } else if (dayInYear <= 59) { - 59 - } else if (dayInYear <= 90) { - 90 - } else if (dayInYear <= 120) { - 120 - } else if (dayInYear <= 151) { - 151 - } else if (dayInYear <= 181) { - 181 - } else if (dayInYear <= 212) { - 212 - } else if (dayInYear <= 243) { - 243 - } else if (dayInYear <= 273) { - 273 - } else if (dayInYear <= 304) { - 304 - } else if (dayInYear <= 334) { - 334 - } else { - 365 - } - date + (lastDayOfMonthInYear - dayInYear) + val (_, _, _, daysToMonthEnd) = splitDate(date) + date + daysToMonthEnd } private val TRUNC_TO_YEAR = 1