import sys
import threading
from collections import deque
from datetime import time, datetime, timedelta, timezone
from time import sleep
from typing import Optional, Deque

# from pynput import keyboard
import keyboard as kb
import numpy as np
# import tensorflow as tf
import tflite_runtime.interpreter as tflite
import vlc

from alarm import BaseAlarmClock, AlarmState

# Per-stage early-wake thresholds, ordered like the model output
# [awake_prob, rem_prob, light_prob, deep_prob]. An early alarm fires once any
# predicted stage probability is strictly greater than its threshold.
# A lower threshold lets that stage trigger an early wake-up at a lower
# predicted probability; a value >= 1 means that stage never allows an early
# wake-up (softmax probabilities never exceed 1).
WAKE_THRESHOLD = np.array([0.5, 0.75, 0.9, 1.0])

# Key names for `keyboard.on_press_key`; the commented-out values are the
# older pynput equivalents.
# SNOOZE_KEY = keyboard.Key.up
SNOOZE_KEY: str = "space"
# ALARM_OFF_KEY = keyboard.Key.esc
ALARM_OFF_KEY: str = "esc"
# Number of one-minute timesteps of history fed to the model.
MODEL_MEMORY: int = 5

SECONDS_IN_MINUTE: float = 60
# How long a snooze pauses the music, in seconds.
SNOOZE_SEC: float = 60
# Polling interval of the waiting loops, in seconds.
WAIT_SEC: float = 1

MODEL_PATH: str = ".model/attention_model.tflite"
# Created at import time. "--input-repeat=999" asks VLC to replay the input
# repeatedly (presumably so the song keeps looping while the alarm rings).
VLC_INSTANCE = vlc.Instance("--input-repeat=999")
VLC_PLAYER = VLC_INSTANCE.media_player_new()
SONG = VLC_INSTANCE.media_new(".model/Viennese Poets.mp3" )

# Model input, one row per minute:
#   [minutes_since_start, current_hour_utc, current_minute_utc, awake_prob, rem_prob, light_prob, deep_prob]
# Model output (after softmax): [awake_prob, rem_prob, light_prob, deep_prob]


def softmax(arr, axis: int = -1):
    """Return the softmax of ``arr`` along ``axis`` (default: the last axis).

    The maximum is subtracted before exponentiating. This does not change the
    result mathematically, but it stops ``np.exp`` from overflowing to ``inf``
    (and the output turning into ``nan``) for large logits.

    Args:
        arr: Array-like of logits.
        axis: Axis along which the returned probabilities sum to 1.

    Returns:
        An ``np.ndarray`` with the same shape as ``arr``.
    """
    logits = np.asarray(arr)
    if logits.size == 0:
        # np.max would raise on an empty array; an empty input has an empty softmax.
        return np.exp(logits)
    shifted = np.exp(logits - np.max(logits, axis=axis, keepdims=True))
    # np.sum rather than builtin sum: builtin sum iterates over the first axis,
    # so a (1, n) batch would be divided by itself and come out as all ones.
    return shifted / np.sum(shifted, axis=axis, keepdims=True)


def model_prediction(model_func, current_times_prob: Deque[np.array]) -> np.array:
    """Run the sleep-stage model on the recent history and return stage probabilities.

    Args:
        model_func: Callable invoked as ``model_func(x=batch)``. In this script it
            is a TFLite signature runner, which returns a dict mapping output names
            to arrays; a callable returning a bare array is accepted as well.
            A falsy value means "no model".
        current_times_prob: The last few timesteps, each a 7-element row
            ``[minutes_since_start, hour_utc, minute_utc, awake, rem, light, deep]``.

    Returns:
        A flat array of four probabilities ``[awake, rem, light, deep]``, or all
        zeros when no model is available.

    Raises:
        ValueError: if the model returns more than one named output.
    """
    if not model_func:
        return np.zeros(4)
    # Add a leading batch dimension: (1, timesteps, features).
    # NOTE(review): this batch is float64; TFLite models usually expect float32
    # inputs - confirm against the model's input details.
    batch = np.array([current_times_prob])
    prediction = model_func(x=batch)
    print(prediction)
    if isinstance(prediction, dict):
        if len(prediction) != 1:
            raise ValueError(f"expected a single model output, got {sorted(prediction)}")
        (prediction,) = prediction.values()
    # Drop the batch dimension so callers get a (4,) vector that can be
    # concatenated onto the next history row.
    return softmax(np.asarray(prediction).reshape(-1))


class SmartAlarmClock(BaseAlarmClock):
    """Alarm clock that can ring early when a sleep-stage model says so.

    On top of the regular alarm check inherited from ``BaseAlarmClock``, the
    alarm may ring from ``earliest_wake`` onwards as soon as any predicted stage
    probability exceeds its entry in ``wake_threshold``. Audio is played through
    a VLC media player.
    """

    def __init__(self,
                 wake_time: Optional[datetime] = None,
                 earliest_wake: Optional[timedelta] = None,
                 wake_threshold: np.array = WAKE_THRESHOLD,
                 vlc_player=VLC_PLAYER,
                 default_song=SONG):
        """Create the clock and, if ``wake_time`` is given, arm it.

        Args:
            wake_time: Alarm time; when given, ``set_alarm`` is called with it.
            earliest_wake: How long before ``wake_time`` an early,
                model-triggered alarm is allowed.
            wake_threshold: Per-stage probability thresholds
                ``[awake, rem, light, deep]``; values >= 1 disable that stage.
            vlc_player: VLC media player used for playback (falsy disables audio).
            default_song: VLC media to play when the alarm sounds.
        """
        super().__init__(wake_time)
        self.earliest_wake = self.wake_time
        if wake_time:
            self.set_alarm(wake_time, earliest_wake)
        # Copy, so in-place edits on one instance cannot leak into the shared
        # module-level default array.
        self.wake_threshold = np.array(wake_threshold, dtype=float)
        self._vlc_player = vlc_player
        self.song = default_song
        if self._vlc_player and self.song:
            self._vlc_player.set_media(self.song)

    def set_alarm(self, wake_time: datetime, earliest_wake: Optional[timedelta] = None) -> datetime:
        """Arm the alarm for ``wake_time`` and set the early-wake window.

        ``earliest_wake`` is an offset before ``wake_time``. When it is missing
        or zero, no early window is used and ``earliest_wake`` equals the alarm time.
        Returns whatever ``BaseAlarmClock.set_alarm`` returns.
        """
        print(f"SETTING ALARM to: {wake_time}")
        returned_value = super(SmartAlarmClock, self).set_alarm(wake_time)
        self.earliest_wake = wake_time - earliest_wake if earliest_wake else self.wake_time
        print(returned_value)
        return returned_value

    def start_alarm(self) -> None:
        """Start the alarm via the base class (logs first)."""
        print("STARTING ALARM")
        super(SmartAlarmClock, self).start_alarm()

    def sound_alarm(self, override_song=None) -> None:
        """Start playback, then notify the base class.

        ``override_song``, when given, permanently replaces ``self.song``.
        """
        print("SOUNDING ALARM")
        if override_song:
            self.song = override_song
        if self._vlc_player:
            self._vlc_player.set_media(self.song)
            self._vlc_player.play()
        super(SmartAlarmClock, self).sound_alarm()

    def snooze_alarm(self) -> None:
        """Pause the music for ``SNOOZE_SEC`` seconds, then sound the alarm again.

        Blocks the calling thread for the whole snooze.
        NOTE(review): when triggered from a keyboard hook, this presumably also
        delays handling of further key presses until the snooze ends.
        """
        print("SNOOZING ALARM")
        super(SmartAlarmClock, self).snooze_alarm()
        if self._vlc_player:
            self._vlc_player.pause()
        sleep(SNOOZE_SEC)
        if self._vlc_player:
            self._vlc_player.play()
        print("UN-SNOOZING ALARM")
        self.sound_alarm()

    def stop_alarm(self, deactivate: bool = True) -> None:
        """Stop playback and stop the alarm via the base class.

        NOTE(review): ``deactivate`` is accepted but neither used nor forwarded;
        confirm whether ``BaseAlarmClock.stop_alarm`` should receive it.
        """
        print("STOPPING ALARM")
        if self._vlc_player:
            self._vlc_player.stop()
            print("STOPPED PLAYING ALARM")
        super(SmartAlarmClock, self).stop_alarm()

    def check_early_alarm_reached(self, current_time: datetime, prior_prediction: np.array) -> bool:
        """Return True if an early (model-triggered) wake-up should happen now.

        That is the case once ``current_time`` is at or after ``earliest_wake``
        and any probability in ``prior_prediction`` is strictly greater than its
        entry in ``wake_threshold``. Returns False when no alarm window is set.
        """
        # Without an armed alarm earliest_wake is None, and comparing a datetime
        # against None would raise TypeError.
        if self.earliest_wake is None or current_time < self.earliest_wake:
            return False
        return bool(np.max(prior_prediction - self.wake_threshold) > 0)

    def smart_alarm_mode(self, fn, model_memory=MODEL_MEMORY):
        """Run the whole cycle: warm up, monitor sleep, then ring until turned off.

        Keys: ``c`` quits the program, ``SNOOZE_KEY`` snoozes, ``ALARM_OFF_KEY`` stops.

        Args:
            fn: Model callable passed to ``model_prediction``. A falsy value
                yields all-zero predictions, so no early wake-up happens.
            model_memory: Number of one-minute timesteps of history fed to the model.
        """
        print("STARTING IN SMART MODE")
        # Keyboard hooks run on the keyboard library's own thread, where
        # sys.exit() would only end that thread. The hook therefore only records
        # the request, and the main thread exits at its next wait.
        quit_requested = threading.Event()

        def wait(seconds: float = WAIT_SEC) -> None:
            """Sleep ``seconds``; exit the program (from this thread) if 'c' was pressed."""
            if quit_requested.wait(seconds):
                sys.exit()

        kb.on_press_key("c", lambda _event: quit_requested.set())
        kb.on_press_key(SNOOZE_KEY, lambda _event: self.snooze_alarm())
        kb.on_press_key(ALARM_OFF_KEY, lambda _event: self.stop_alarm())
        self.start_alarm()

        # Warm-up: fill the model memory with one "awake" row per minute.
        time_queue = deque(maxlen=model_memory)
        minutes_since_start = 0
        while len(time_queue) < time_queue.maxlen:
            current_utc = datetime.now(tz=timezone.utc)
            time_queue.append(np.array(
                [minutes_since_start, current_utc.hour, current_utc.minute, 1.0, 0.0, 0.0, 0.0]
            ))
            print(f"time_queue: {time_queue}")
            # Wait a minute for the next timestep.
            while (datetime.now(tz=timezone.utc) - current_utc).total_seconds() < SECONDS_IN_MINUTE:
                wait()
            minutes_since_start += 1

        # Monitoring: record one timestep per minute and re-predict, until the
        # early-wake condition or the regular alarm check fires.
        last_checked = datetime.now(tz=timezone.utc)
        last_prediction = model_prediction(fn, time_queue)
        print(f"{last_checked}: last_prediction = {last_prediction}")
        print(f"check early: {self.check_early_alarm_reached(datetime.now(tz=timezone.utc), last_prediction)}")
        last_appended: Optional[datetime] = None  # None: record the first timestep immediately
        while (self.current_state is AlarmState.RUNNING
               and not self.check_early_alarm_reached(datetime.now(tz=timezone.utc), last_prediction)
               and not self.alarm_check_reached(datetime.now(tz=timezone.utc))):
            print(f"ALARM SLEEPING @: {datetime.now()}")
            now = datetime.now(tz=timezone.utc)
            if last_appended is None or (now - last_appended).total_seconds() >= SECONDS_IN_MINUTE:
                last_appended = now
                time_queue.append(
                    np.concatenate(([minutes_since_start, now.hour, now.minute], last_prediction))
                )
                minutes_since_start += 1
                print(f"time_queue = {time_queue}")
                # The history only changes here, so predicting more often is wasted work.
                last_prediction = model_prediction(fn, time_queue)
                print(f"{datetime.now(tz=timezone.utc)}: last_prediction = {last_prediction}")
            wait()

        # Ringing: sound the alarm unless it was already turned off while
        # sleeping, then idle (without busy-spinning) until it is deactivated.
        if self.current_state is not AlarmState.DEACTIVATED:
            self.sound_alarm()
        while self.current_state is not AlarmState.DEACTIVATED:
            wait()


def is_affirmative(response: str) -> bool:
    """Return True if ``response`` is a yes-like answer (y/yes/t/true), ignoring case and surrounding whitespace."""
    return response.strip().lower() in {"y", "yes", "t", "true"}


if __name__ == '__main__':
    alarm_hour = int(input("What hour do you want the alarm to go off at? "))
    alarm_minute = int(input("What minute do you want the alarm to go off at? "))
    early_minutes = int(input("How many minute early would you have the alarm go off at? "))
    # Empty or any other answer means "no" (the prompt's default is N).
    run_alarm_today = is_affirmative(input("Will the alarm run today? [y/N]"))
    alarm_day = datetime.today().date()
    if not run_alarm_today:
        alarm_day += timedelta(days=1)
    # The entered time is wall-clock local time: astimezone() on a naive datetime
    # treats it as system-local time before converting it to UTC.
    alarm_time = datetime.combine(alarm_day, time(alarm_hour, alarm_minute, 0)).astimezone(tz=timezone.utc)
    alarm_clock_class = SmartAlarmClock(alarm_time, earliest_wake=timedelta(minutes=early_minutes))
    print(f"Loading model from: {MODEL_PATH}")
    ai_model = tflite.Interpreter(model_path=MODEL_PATH)
    ai_model.allocate_tensors()  # tf.keras.models.load_model(MODEL_PATH)
    model_fn = ai_model.get_signature_runner()
    print("Initialized Model")
    alarm_clock_class.smart_alarm_mode(model_fn)