from common import *

# Model hyper-parameters.
# Number of hidden units in each LSTM layer of the model built by main().
LSTM_UNITS: int = 16

# Model Definition
def _build_lstm_model(units, num_layers):
    """Build the (uncompiled) stacked-LSTM sequence classifier.

    Args:
        units: number of hidden units in every LSTM layer.
        num_layers: number of stacked LSTM layers; must be >= 1.

    Returns:
        A built ``keras.Sequential`` model whose input shape is
        ``(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)`` (batch dimension implicit)
        and whose output has ``SLEEP_STAGES`` units.

    Raises:
        ValueError: if ``num_layers`` is less than 1.
    """
    if num_layers < 1:
        raise ValueError(f"num_layers must be >= 1, got {num_layers}")

    model = keras.Sequential()
    model.add(layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)))
    # Every LSTM except the last must return its full output sequence so the
    # next LSTM still receives per-time-step input; the last one returns only
    # its final output, which feeds the Dense head.
    for index in range(num_layers):
        is_last = index == num_layers - 1
        model.add(layers.LSTM(units, stateful=False, return_sequences=not is_last))
    # No activation is set on the head. NOTE(review): presumably the loss in
    # compile_and_fit treats these outputs as logits -- confirm.
    model.add(layers.Dense(SLEEP_STAGES))
    model.build()
    return model


def main(baseline=None, *, units=None, num_layers=1):
    """Build, train and persist the LSTM model.

    Args:
        baseline: forwarded unchanged to ``compile_and_fit``; its meaning is
            defined there (not visible in this file).
        units: hidden units per LSTM layer; ``None`` means ``LSTM_UNITS``.
        num_layers: number of stacked LSTM layers (default 1, the original
            single-layer architecture).

    Side effects: prints the model summary, trains on the module-level
    window ``wg`` (from ``common``), saves the model to ``.model/lstm`` and
    the training history to ``.history/lstm.csv``.
    """
    if units is None:
        units = LSTM_UNITS
    lstm_model = _build_lstm_model(units, num_layers)
    # summary() prints the table itself and returns None; wrapping it in
    # print() would emit an extra "None" line.
    lstm_model.summary()
    lstm_history = compile_and_fit(model=lstm_model, window=wg, baseline=baseline)
    print("Finished training")
    lstm_model.save(".model/lstm")
    save_history(lstm_history.history, ".history/lstm.csv")

if __name__ == "__main__":
    # Script entry point: train without a baseline model.
    main(baseline=None)