Skip to content
Snippets Groups Projects
lstm_model.py 820 B
Newer Older
  • Learn to ignore specific revisions
  • from common import *
    
    # Hyper-parameters
    # Width (number of units) of the LSTM layer(s) that main() builds.
    LSTM_UNITS: int = 16
    
    # Model Definition
    def main(baseline=None, *, lstm_units=LSTM_UNITS, num_lstm_layers=1):
        """Build, train and save the LSTM model.

        The network is an ``Input`` layer of shape
        ``(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)``, followed by a stack of
        ``num_lstm_layers`` LSTM layers and a ``Dense`` layer with
        ``SLEEP_STAGES`` outputs (no activation argument, so Keras' default
        linear activation applies). Training is delegated to
        ``compile_and_fit`` with the module-level window ``wg``. Both names
        presumably come from ``common`` (verify).

        After training, the model is saved to ``.model/lstm``. The training
        history is written to ``.history/lstm.csv`` via ``save_history``.

        Args:
            baseline: Forwarded unchanged to ``compile_and_fit``; its meaning
                is defined there.
            lstm_units: Number of units in each LSTM layer. Defaults to
                ``LSTM_UNITS``.
            num_lstm_layers: Number of LSTM layers to stack. Every layer but
                the last returns full sequences so the next LSTM still gets a
                time series. The last returns only its final output, which is
                what the ``Dense`` head consumes. The default of 1 builds the
                original single-layer model.

        Raises:
            ValueError: If ``num_lstm_layers`` is less than 1.
        """
        if num_lstm_layers < 1:
            raise ValueError(
                f"num_lstm_layers must be >= 1, got {num_lstm_layers}"
            )

        lstm_model = keras.Sequential()
        lstm_model.add(layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)))
        for layer_index in range(num_lstm_layers):
            # Only the final LSTM collapses the time axis for the Dense head.
            is_last = layer_index == num_lstm_layers - 1
            lstm_model.add(
                layers.LSTM(lstm_units, stateful=False, return_sequences=not is_last)
            )
        lstm_model.add(layers.Dense(SLEEP_STAGES))
        lstm_model.build()
        # summary() prints the table itself and returns None. Wrapping it in
        # print() would emit an extra "None" line.
        lstm_model.summary()

        lstm_history = compile_and_fit(model=lstm_model, window=wg, baseline=baseline)
        print("Finished training")
        lstm_model.save(".model/lstm")
        save_history(lstm_history.history, ".history/lstm.csv")
    
    if __name__ == "__main__":
        # Script entry point: run training with main()'s default baseline (None).
        main(baseline=None)