from common import *


def main():
    """Build, train and save the dense baseline sleep-stage model.

    Side effects:
      * rebinds the module-level global ``baseline_model`` to the new model;
      * prints the model summary and a completion message;
      * writes the training history to ``.history/baseline.csv`` (via
        ``save_history``) and the model to ``.model/baseline``.

    Relies on names star-imported from ``common``: ``keras``, ``layers``,
    ``INPUT_TIME_STEP``, ``INPUT_FEATURES_SIZE``, ``SLEEP_STAGES``,
    ``compile_and_fit``, ``wg`` and ``save_history``.
    """
    # Kept global so the trained model stays reachable at module level after
    # main() returns.
    global baseline_model
    BASELINE_UNITS = 16
    baseline_model = keras.Sequential(
        [
            layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)),
            # Dense on a rank-3 input acts on the last axis only, so each of
            # the INPUT_TIME_STEP rows is projected independently before the
            # result is flattened.
            layers.Dense(BASELINE_UNITS),
            layers.Flatten(),
            # No activation: presumably raw logits, with the softmax /
            # from_logits handling done by compile_and_fit's loss — TODO
            # confirm.
            layers.Dense(SLEEP_STAGES),
        ]
    )
    # The leading Input layer already builds the model, so no explicit
    # build() call is needed. summary() prints the table itself and returns
    # None; wrapping it in print() used to emit a stray "None" line.
    baseline_model.summary()
    baseline_history = compile_and_fit(baseline_model, wg)
    print("Finished training")
    save_history(baseline_history.history, ".history/baseline.csv")
    # NOTE(review): Keras 3 rejects save paths without a ".keras"/".h5"
    # suffix; this extensionless form only works with TF2-era Keras
    # (SavedModel). Confirm the pinned Keras version.
    baseline_model.save(".model/baseline")

if __name__ == "__main__":
    # Script entry point: train and save the baseline only when this file is
    # executed directly, never as a side effect of importing it.
    main()