Skip to content
Snippets Groups Projects
baseline_model.py 724 B
Newer Older
  • Learn to ignore specific revisions
  • from common import *
    
    
    def main():
        """Build, train, and save the baseline Keras model.

        The model is a small ``keras.Sequential`` stack:
        ``Input(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)`` -> ``Dense(16)`` ->
        ``Flatten`` -> ``Dense(SLEEP_STAGES)``. It is trained with
        ``compile_and_fit(model, wg)``. The resulting training history is
        written via ``save_history`` to ``.history/baseline.csv``, and the model
        is saved to ``.model/baseline``.

        ``keras``, ``layers``, ``INPUT_TIME_STEP``, ``INPUT_FEATURES_SIZE``,
        ``SLEEP_STAGES``, ``compile_and_fit``, ``save_history`` and ``wg`` are
        not defined in this block. Presumably they all come from
        ``from common import *``; confirm against ``common``.

        Side effect: the built model is bound to the module-level name
        ``baseline_model`` (via ``global``), so it remains reachable after
        ``main()`` returns.
        """
        global baseline_model
        BASELINE_UNITS = 16
        baseline_model = keras.Sequential(
            [
                layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)),
                # No activation is given, so this layer is linear. On a rank-3
                # input, Keras Dense operates on the last (feature) axis only.
                layers.Dense(BASELINE_UNITS),
                layers.Flatten(),
                layers.Dense(SLEEP_STAGES),
            ]
        )
        baseline_model.build()
        # Model.summary() prints the table itself and returns None.
        # Wrapping it in print() would emit an extra "None" line.
        baseline_model.summary()
        baseline_history = compile_and_fit(baseline_model, wg)
        print("Finished training")
        save_history(baseline_history.history, ".history/baseline.csv")
        baseline_model.save(".model/baseline")


    if __name__ == "__main__":
        main()