Skip to content
Snippets Groups Projects
attention_model.py 640 B
from common import *
from attention import Attention

# Hyper-parameters
# Positional argument handed to the project's Attention layer in main();
# presumably the layer's internal width — TODO confirm against attention.py.
ATTENTION_UNITS: int = 16

# Model Definition
def main(baseline=None):
    """Build, train and save the attention-based model.

    The model is a ``keras.Sequential`` stack of
    ``Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE))`` ->
    ``Attention(ATTENTION_UNITS)`` -> ``Dense(SLEEP_STAGES)``. No activation
    is passed to the final Dense layer, so Keras' default linear activation
    applies. The model is trained with the shared ``compile_and_fit`` helper
    on the ``wg`` window. The trained model is then written to
    ``.model/attention`` and the training history to
    ``.history/attention.csv``.

    Args:
        baseline: Forwarded unchanged to ``compile_and_fit``; its meaning is
            defined there (presumably an early-stopping baseline — TODO
            confirm in common.py).
    """
    am_model = keras.Sequential()
    am_model.add(layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)))
    am_model.add(Attention(ATTENTION_UNITS))
    am_model.add(layers.Dense(SLEEP_STAGES))
    # The Input layer already fixes the input shape, so this is likely
    # redundant; kept to avoid depending on Keras-version build semantics.
    am_model.build()
    # summary() prints the table itself and returns None; wrapping it in
    # print() used to emit a stray "None" line after the summary.
    am_model.summary()
    # NOTE(review): `wg` is not defined in this module; presumably it is a
    # window object exported by `from common import *` — confirm.
    am_history = compile_and_fit(model=am_model, window=wg, baseline=baseline)
    print("Finished training")
    # NOTE(review): Keras 3's Model.save requires a ".keras"/".h5" suffix;
    # this extension-less path relies on the TF2 SavedModel format.
    am_model.save(".model/attention")
    save_history(am_history.history, ".history/attention.csv")

if __name__ == "__main__":
    # Script entry point: train without a baseline (same as main()'s default).
    main(baseline=None)