{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Notebook for training and testing AI models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import numpy as np\n", "from tensorflow import keras\n", "from tensorflow.keras import layers\n", "import pandas as pd\n", "from attention import CustomAttention\n", "import json" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.8.0\n" ] } ], "source": [ "print(tf.__version__)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# CONSTANTS\n", "RAW_SLEEP_DATA_PATH = \".data/raw_bed_sleep-state.csv\"\n", "CLEANED_SLEEP_DATA_PATH = \".data/clean_bed_sleep-state.csv\"\n", "SLEEP_DATA_PATH = \".data/sleep_data_simple.csv\"\n", "UPDATED_SLEEP_DATA_PATH = \".data/updated_sleep_data.csv\"" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "## Parameters and Hyper-parameters\n", "SLEEP_STAGES = 4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Cleaning Raw Data" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "import csv\n", "import datetime\n", "import itertools" ] }, { "cell_type": "code", "execution_count": 104, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "datetime.datetime(2022, 4, 21, 10, 19, tzinfo=datetime.timezone(datetime.timedelta(seconds=7200)))" ] }, "execution_count": 104, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datetime.datetime.strptime(\"2022-04-21T10:18:00+02:00\",\"%Y-%m-%dT%H:%M:%S%z\") + datetime.timedelta(minutes=1)" ] }, { "cell_type": "code", "execution_count": 115, "metadata": {}, "outputs": [], "source": [ "def stage_probability(stage, to_test):\n", " return 1.0 if stage == to_test else 0.0" ] }, { "cell_type": "code", "execution_count": 150, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6.833333333333333" ] }, "execution_count": 150, "metadata": {}, "output_type": "execute_result" } ], "source": [ "((start_time - in_bed_time).seconds)/3600" ] }, { "cell_type": "code", "execution_count": 201, "metadata": {}, "outputs": [], "source": [ "cleaned_info = []\n", "date_seen = set()\n", "previous_duration = 60\n", "with open(RAW_SLEEP_DATA_PATH, mode ='r') as raw_file:\n", " csvFile = csv.reader(raw_file)\n", " # max_count = 1\n", " # stuff = set()\n", " in_bed_time = None\n", " current_sleep_id = -1\n", " for index, lines in enumerate(csvFile):\n", " if index == 0:\n", " cleaned_info.append([\n", " \"sleep_id\",\n", " \"sleep_begin\",\n", " \"stage_start\",\n", " \"time_since_begin_sec\",\n", " \"stage_duration_sec\",\n", " \"stage_end\", \n", " \"stage_value\",\n", " \"awake_probability\",\n", " \"light_probability\",\n", " \"deep_probability\",\n", " \"rem_probability\",\n", " ])\n", " continue\n", " start_time = datetime.datetime.strptime(lines[0],\"%Y-%m-%dT%H:%M:%S%z\")\n", " if start_time in date_seen:\n", " continue\n", " date_seen.add(start_time)\n", " if not in_bed_time or in_bed_time > start_time:\n", " current_sleep_id += 1\n", " in_bed_time = start_time\n", " # for duration, stage in enumerate(\n", " # for offset, (duration, stage) in enumerate(\n", " # zip(\n", " # # itertools.accumulate(lines[1].strip(\"[]\").split(\",\"), lambda x,y: int(x)+int(y)//60, initial = 0), \n", " # map(int, lines[1].strip(\"[]\").split(\",\"))\n", " # map(int, lines[2].strip(\"[]\").split(\",\"))\n", " # )\n", " # # map(int, lines[2].strip(\"[]\").split(\",\"))\n", " # ):\n", " for offset, (duration, stage) in enumerate(zip(map(int, lines[1].strip(\"[]\").split(\",\")), map(int, lines[2].strip(\"[]\").split(\",\")))):\n", " # print(f\"{(index, subindex) = }, {duration = }, {stage = }\")\n", " # print(f\"{(index, duration) = } {stage = }\")\n", " current_time = start_time + datetime.timedelta(seconds=offset*previous_duration)\n", " cleaned_info.append([\n", " current_sleep_id,\n", " in_bed_time,\n", " current_time, \n", " (current_time - in_bed_time).seconds,\n", " duration, \n", " current_time + datetime.timedelta(seconds=duration), \n", " stage,\n", " stage_probability(0, stage),\n", " stage_probability(1, stage),\n", " stage_probability(2, stage),\n", " stage_probability(3, stage),\n", " ])\n", " previous_duration = duration\n", " # print(f\"{(index, subindex) = }, {val = }\")\n", " # print(list())\n", " # if index >= max_count:\n", " # break\n" ] }, { "cell_type": "code", "execution_count": 202, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Finished Writing Cleaned Data\n" ] } ], "source": [ "with open(CLEANED_SLEEP_DATA_PATH, 'w') as clean_file:\n", " write = csv.writer(clean_file)\n", " write.writerows(cleaned_info)\n", "print(\"Finished Writing Cleaned Data\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creating DataFrame from clean raw data" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [], "source": [ "# Get the cleaned data\n", "sleep_df_raw = pd.read_csv(CLEANED_SLEEP_DATA_PATH)#, parse_dates=[\"start\", \"end\"], infer_datetime_format=True)" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [], "source": [ "# Preprocess data: \n", "# 1. convert to datetime\n", "sleep_df_raw[\"sleep_begin\"] = pd.to_datetime(sleep_df_raw[\"sleep_begin\"], utc=True)\n", "sleep_df_raw[\"stage_start\"] = pd.to_datetime(sleep_df_raw[\"stage_start\"], utc=True)\n", "sleep_df_raw[\"stage_end\"] = pd.to_datetime(sleep_df_raw[\"stage_end\"], utc=True)\n", "# 2. Separate time, hour and minute\n", "# MAYBE 3. smaller units: int16 or int8 " ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "def get_minute(row, index):\n", " return row[index].time().minute\n", "\n", "def get_hour(row, index):\n", " return row[index].time().hour" ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [], "source": [ "sleep_df_raw[\"stage_start_hour\"] = sleep_df_raw.apply (lambda row: get_hour(row, \"stage_start\"), axis=1)\n", "sleep_df_raw[\"stage_start_minute\"] = sleep_df_raw.apply (lambda row: get_minute(row, \"stage_start\"), axis=1)" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 'pandas.core.frame.DataFrame'>\n", "RangeIndex: 551042 entries, 0 to 551041\n", "Data columns (total 13 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 sleep_id 551042 non-null int64 \n", " 1 sleep_begin 551042 non-null datetime64[ns, UTC]\n", " 2 stage_start 551042 non-null datetime64[ns, UTC]\n", " 3 time_since_begin_sec 551042 non-null int64 \n", " 4 stage_duration_sec 551042 non-null int64 \n", " 5 stage_end 551042 non-null datetime64[ns, UTC]\n", " 6 stage_value 551042 non-null int64 \n", " 7 awake_probability 551042 non-null float64 \n", " 8 light_probability 551042 non-null float64 \n", " 9 deep_probability 551042 non-null float64 \n", " 10 rem_probability 551042 non-null float64 \n", " 11 stage_start_hour 551042 non-null int64 \n", " 12 stage_start_minute 551042 non-null int64 \n", "dtypes: datetime64[ns, UTC](3), float64(4), int64(6)\n", "memory usage: 54.7 MB\n" ] } ], "source": [ "sleep_df_raw.info()" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>sleep_id</th>\n", " <th>sleep_begin</th>\n", " <th>stage_start</th>\n", " <th>time_since_begin_sec</th>\n", " <th>stage_duration_sec</th>\n", " <th>stage_end</th>\n", " <th>stage_value</th>\n", " <th>awake_probability</th>\n", " <th>light_probability</th>\n", " <th>deep_probability</th>\n", " <th>rem_probability</th>\n", " <th>stage_start_hour</th>\n", " <th>stage_start_minute</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>2022-04-21 08:18:00+00:00</td>\n", " <td>2022-04-21 08:18:00+00:00</td>\n", " <td>0</td>\n", " <td>60</td>\n", " <td>2022-04-21 08:19:00+00:00</td>\n", " <td>0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>8</td>\n", " <td>18</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0</td>\n", " <td>2022-04-21 08:18:00+00:00</td>\n", " <td>2022-04-21 08:19:00+00:00</td>\n", " <td>60</td>\n", " <td>60</td>\n", " <td>2022-04-21 08:20:00+00:00</td>\n", " <td>0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>8</td>\n", " <td>19</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>0</td>\n", " <td>2022-04-21 08:18:00+00:00</td>\n", " <td>2022-04-21 08:20:00+00:00</td>\n", " <td>120</td>\n", " <td>60</td>\n", " <td>2022-04-21 08:21:00+00:00</td>\n", " <td>0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>8</td>\n", " <td>20</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>0</td>\n", " <td>2022-04-21 08:18:00+00:00</td>\n", " <td>2022-04-21 08:21:00+00:00</td>\n", " <td>180</td>\n", " <td>60</td>\n", " <td>2022-04-21 08:22:00+00:00</td>\n", " <td>0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>8</td>\n", " <td>21</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>0</td>\n", " <td>2022-04-21 08:18:00+00:00</td>\n", " <td>2022-04-21 08:22:00+00:00</td>\n", " <td>240</td>\n", " <td>60</td>\n", " <td>2022-04-21 08:23:00+00:00</td>\n", " <td>0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>8</td>\n", " <td>22</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>551037</th>\n", " <td>1132</td>\n", " <td>2019-02-11 06:11:00+00:00</td>\n", " <td>2019-02-11 13:17:00+00:00</td>\n", " <td>25560</td>\n", " <td>60</td>\n", " <td>2019-02-11 13:18:00+00:00</td>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>13</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>551038</th>\n", " <td>1132</td>\n", " <td>2019-02-11 06:11:00+00:00</td>\n", " <td>2019-02-11 13:18:00+00:00</td>\n", " <td>25620</td>\n", " <td>60</td>\n", " <td>2019-02-11 13:19:00+00:00</td>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>13</td>\n", " <td>18</td>\n", " </tr>\n", " <tr>\n", " <th>551039</th>\n", " <td>1132</td>\n", " <td>2019-02-11 06:11:00+00:00</td>\n", " <td>2019-02-11 13:19:00+00:00</td>\n", " <td>25680</td>\n", " <td>60</td>\n", " <td>2019-02-11 13:20:00+00:00</td>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>13</td>\n", " <td>19</td>\n", " </tr>\n", " <tr>\n", " <th>551040</th>\n", " <td>1132</td>\n", " <td>2019-02-11 06:11:00+00:00</td>\n", " <td>2019-02-11 13:20:00+00:00</td>\n", " <td>25740</td>\n", " <td>60</td>\n", " <td>2019-02-11 13:21:00+00:00</td>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>13</td>\n", " <td>20</td>\n", " </tr>\n", " <tr>\n", " <th>551041</th>\n", " <td>1132</td>\n", " <td>2019-02-11 06:11:00+00:00</td>\n", " <td>2019-02-11 13:21:00+00:00</td>\n", " <td>25800</td>\n", " <td>60</td>\n", " <td>2019-02-11 13:22:00+00:00</td>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>13</td>\n", " <td>21</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>551042 rows × 13 columns</p>\n", "</div>" ], "text/plain": [ " sleep_id sleep_begin stage_start \\\n", "0 0 2022-04-21 08:18:00+00:00 2022-04-21 08:18:00+00:00 \n", "1 0 2022-04-21 08:18:00+00:00 2022-04-21 08:19:00+00:00 \n", "2 0 2022-04-21 08:18:00+00:00 2022-04-21 08:20:00+00:00 \n", "3 0 2022-04-21 08:18:00+00:00 2022-04-21 08:21:00+00:00 \n", "4 0 2022-04-21 08:18:00+00:00 2022-04-21 08:22:00+00:00 \n", "... ... ... ... \n", "551037 1132 2019-02-11 06:11:00+00:00 2019-02-11 13:17:00+00:00 \n", "551038 1132 2019-02-11 06:11:00+00:00 2019-02-11 13:18:00+00:00 \n", "551039 1132 2019-02-11 06:11:00+00:00 2019-02-11 13:19:00+00:00 \n", "551040 1132 2019-02-11 06:11:00+00:00 2019-02-11 13:20:00+00:00 \n", "551041 1132 2019-02-11 06:11:00+00:00 2019-02-11 13:21:00+00:00 \n", "\n", " time_since_begin_sec stage_duration_sec stage_end \\\n", "0 0 60 2022-04-21 08:19:00+00:00 \n", "1 60 60 2022-04-21 08:20:00+00:00 \n", "2 120 60 2022-04-21 08:21:00+00:00 \n", "3 180 60 2022-04-21 08:22:00+00:00 \n", "4 240 60 2022-04-21 08:23:00+00:00 \n", "... ... ... ... \n", "551037 25560 60 2019-02-11 13:18:00+00:00 \n", "551038 25620 60 2019-02-11 13:19:00+00:00 \n", "551039 25680 60 2019-02-11 13:20:00+00:00 \n", "551040 25740 60 2019-02-11 13:21:00+00:00 \n", "551041 25800 60 2019-02-11 13:22:00+00:00 \n", "\n", " stage_value awake_probability light_probability deep_probability \\\n", "0 0 1.0 0.0 0.0 \n", "1 0 1.0 0.0 0.0 \n", "2 0 1.0 0.0 0.0 \n", "3 0 1.0 0.0 0.0 \n", "4 0 1.0 0.0 0.0 \n", "... ... ... ... ... \n", "551037 1 0.0 1.0 0.0 \n", "551038 1 0.0 1.0 0.0 \n", "551039 1 0.0 1.0 0.0 \n", "551040 1 0.0 1.0 0.0 \n", "551041 1 0.0 1.0 0.0 \n", "\n", " rem_probability stage_start_hour stage_start_minute \n", "0 0.0 8 18 \n", "1 0.0 8 19 \n", "2 0.0 8 20 \n", "3 0.0 8 21 \n", "4 0.0 8 22 \n", "... ... ... ... \n", "551037 0.0 13 17 \n", "551038 0.0 13 18 \n", "551039 0.0 13 19 \n", "551040 0.0 13 20 \n", "551041 0.0 13 21 \n", "\n", "[551042 rows x 13 columns]" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sleep_df_raw" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "sleep_data = sleep_df_raw[[\"sleep_id\", \"stage_start_hour\", \"stage_start_minute\", \"awake_probability\", \"rem_probability\",\"light_probability\", \"deep_probability\"]]\n", "sleep_data.insert(loc=1, column=\"minutes_since_begin\" , value= sleep_df_raw[\"time_since_begin_sec\"]//60)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " sleep_id minutes_since_begin stage_start_hour stage_start_minute \\\n", "0 0 0 8 18 \n", "1 0 1 8 19 \n", "2 0 2 8 20 \n", "3 0 3 8 21 \n", "4 0 4 8 22 \n", "\n", " awake_probability rem_probability light_probability deep_probability \n", "0 1.0 0.0 0.0 0.0 \n", "1 1.0 0.0 0.0 0.0 \n", "2 1.0 0.0 0.0 0.0 \n", "3 1.0 0.0 0.0 0.0 \n", "4 1.0 0.0 0.0 0.0 \n", "<class 'pandas.core.frame.DataFrame'>\n", "RangeIndex: 551042 entries, 0 to 551041\n", "Data columns (total 8 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 sleep_id 551042 non-null int64 \n", " 1 minutes_since_begin 551042 non-null int64 \n", " 2 stage_start_hour 551042 non-null int64 \n", " 3 stage_start_minute 551042 non-null int64 \n", " 4 awake_probability 551042 non-null float64\n", " 5 rem_probability 551042 non-null float64\n", " 6 light_probability 551042 non-null float64\n", " 7 deep_probability 551042 non-null float64\n", "dtypes: float64(4), int64(4)\n", "memory usage: 33.6 MB\n", "None\n" ] } ], "source": [ "print(sleep_data.head())\n", "print(sleep_data.info())" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "sleep_data.to_csv(\".data/sleep_data_simple.csv\", index=False, index_label=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model Development" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "TEST_SIZE = 365//2\n", "VALIDATION_SIZE = 365\n", "\n", "BATCH_SIZE = 64\n", "INPUT_TIME_STEP = 5 # in minutes\n", "INPUT_FEATURES_SIZE = 7\n", "MAX_EPOCHS = 20" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "SAMPLE_COUNT = 1000\n", "# A R L D\n", "# A\n", "# R\n", "# L\n", "# D\n", "CONFUSION_MATRIX = np.array(\n", " [\n", " [66.2, 5.0, 22.5, 6.2],\n", " [1.6, 60.7, 33.0, 4.7],\n", " [3.8, 22.3, 55.4, 18.5],\n", " [0.0, 1.3, 26.7, 72.0],\n", " ]\n", ")/100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import Data" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# sleep_data = pd.read_csv(SLEEP_DATA_PATH)\n", "sleep_data = pd.read_csv(UPDATED_SLEEP_DATA_PATH)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>sleep_id</th>\n", " <th>minutes_since_begin</th>\n", " <th>stage_start_hour</th>\n", " <th>stage_start_minute</th>\n", " <th>awake_probability_noisy</th>\n", " <th>rem_probability_noisy</th>\n", " <th>light_probability_noisy</th>\n", " <th>deep_probability_noisy</th>\n", " <th>awake_probability_original</th>\n", " <th>rem_probability_original</th>\n", " <th>light_probability_original</th>\n", " <th>deep_probability_original</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>8.0</td>\n", " <td>18.0</td>\n", " <td>0.680</td>\n", " <td>0.057</td>\n", " <td>0.200</td>\n", " <td>0.063</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>8.0</td>\n", " <td>19.0</td>\n", " <td>0.652</td>\n", " <td>0.060</td>\n", " <td>0.224</td>\n", " <td>0.064</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>0.0</td>\n", " <td>2.0</td>\n", " <td>8.0</td>\n", " <td>20.0</td>\n", " <td>0.672</td>\n", " <td>0.059</td>\n", " <td>0.209</td>\n", " <td>0.060</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>0.0</td>\n", " <td>3.0</td>\n", " <td>8.0</td>\n", " <td>21.0</td>\n", " <td>0.645</td>\n", " <td>0.056</td>\n", " <td>0.235</td>\n", " <td>0.064</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>0.0</td>\n", " <td>4.0</td>\n", " <td>8.0</td>\n", " <td>22.0</td>\n", " <td>0.644</td>\n", " <td>0.054</td>\n", " <td>0.244</td>\n", " <td>0.058</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>551037</th>\n", " <td>1132.0</td>\n", " <td>426.0</td>\n", " <td>13.0</td>\n", " <td>17.0</td>\n", " <td>0.041</td>\n", " <td>0.193</td>\n", " <td>0.576</td>\n", " <td>0.190</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>551038</th>\n", " <td>1132.0</td>\n", " <td>427.0</td>\n", " <td>13.0</td>\n", " <td>18.0</td>\n", " <td>0.027</td>\n", " <td>0.209</td>\n", " <td>0.563</td>\n", " <td>0.201</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>551039</th>\n", " <td>1132.0</td>\n", " <td>428.0</td>\n", " <td>13.0</td>\n", " <td>19.0</td>\n", " <td>0.032</td>\n", " <td>0.220</td>\n", " <td>0.574</td>\n", " <td>0.174</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>551040</th>\n", " <td>1132.0</td>\n", " <td>429.0</td>\n", " <td>13.0</td>\n", " <td>20.0</td>\n", " <td>0.036</td>\n", " <td>0.256</td>\n", " <td>0.530</td>\n", " <td>0.178</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>551041</th>\n", " <td>1132.0</td>\n", " <td>430.0</td>\n", " <td>13.0</td>\n", " <td>21.0</td>\n", " <td>0.033</td>\n", " <td>0.205</td>\n", " <td>0.571</td>\n", " <td>0.191</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>551042 rows × 12 columns</p>\n", "</div>" ], "text/plain": [ " sleep_id minutes_since_begin stage_start_hour stage_start_minute \\\n", "0 0.0 0.0 8.0 18.0 \n", "1 0.0 1.0 8.0 19.0 \n", "2 0.0 2.0 8.0 20.0 \n", "3 0.0 3.0 8.0 21.0 \n", "4 0.0 4.0 8.0 22.0 \n", "... ... ... ... ... \n", "551037 1132.0 426.0 13.0 17.0 \n", "551038 1132.0 427.0 13.0 18.0 \n", "551039 1132.0 428.0 13.0 19.0 \n", "551040 1132.0 429.0 13.0 20.0 \n", "551041 1132.0 430.0 13.0 21.0 \n", "\n", " awake_probability_noisy rem_probability_noisy \\\n", "0 0.680 0.057 \n", "1 0.652 0.060 \n", "2 0.672 0.059 \n", "3 0.645 0.056 \n", "4 0.644 0.054 \n", "... ... ... \n", "551037 0.041 0.193 \n", "551038 0.027 0.209 \n", "551039 0.032 0.220 \n", "551040 0.036 0.256 \n", "551041 0.033 0.205 \n", "\n", " light_probability_noisy deep_probability_noisy \\\n", "0 0.200 0.063 \n", "1 0.224 0.064 \n", "2 0.209 0.060 \n", "3 0.235 0.064 \n", "4 0.244 0.058 \n", "... ... ... \n", "551037 0.576 0.190 \n", "551038 0.563 0.201 \n", "551039 0.574 0.174 \n", "551040 0.530 0.178 \n", "551041 0.571 0.191 \n", "\n", " awake_probability_original rem_probability_original \\\n", "0 1.0 0.0 \n", "1 1.0 0.0 \n", "2 1.0 0.0 \n", "3 1.0 0.0 \n", "4 1.0 0.0 \n", "... ... ... \n", "551037 0.0 0.0 \n", "551038 0.0 0.0 \n", "551039 0.0 0.0 \n", "551040 0.0 0.0 \n", "551041 0.0 0.0 \n", "\n", " light_probability_original deep_probability_original \n", "0 0.0 0.0 \n", "1 0.0 0.0 \n", "2 0.0 0.0 \n", "3 0.0 0.0 \n", "4 0.0 0.0 \n", "... ... ... \n", "551037 1.0 0.0 \n", "551038 1.0 0.0 \n", "551039 1.0 0.0 \n", "551040 1.0 0.0 \n", "551041 1.0 0.0 \n", "\n", "[551042 rows x 12 columns]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sleep_data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create Randmonizations" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "def softmax(arr):\n", " val = np.exp(arr)\n", " return val / sum(val)" ] }, { "cell_type": "code", "execution_count": 133, "metadata": {}, "outputs": [], "source": [ "def noisy_randomizer(confusion_matrix_index: int, sample_count: int = SAMPLE_COUNT, confusion_matrix = CONFUSION_MATRIX):\n", " return np.random.multinomial(sample_count, confusion_matrix[confusion_matrix_index], size=1)[0]/sample_count" ] }, { "cell_type": "code", "execution_count": 173, "metadata": {}, "outputs": [], "source": [ "def create_updated(dataframe):\n", " dataframe_array = dataframe.to_numpy()\n", " updated_probs = np.array(list(map(noisy_randomizer, np.argmax(dataframe_array[:, -4:], axis=1))))\n", " columns_tmp = dataframe.columns.to_list() \n", " noisy_column = list(map(lambda name: f\"{name}_noisy\", columns_tmp[-4:]))\n", " og_column = list(map(lambda name: f\"{name}_original\", columns_tmp[-4:]))\n", " return pd.DataFrame(np.concatenate([dataframe_array[:, :-4], updated_probs, dataframe_array[:, -4:]], axis=1), columns = columns_tmp[:-4]+noisy_column+og_column)" ] }, { "cell_type": "code", "execution_count": 174, "metadata": {}, "outputs": [], "source": [ "sleep_data_updated = create_updated(sleep_data)" ] }, { "cell_type": "code", "execution_count": 176, "metadata": {}, "outputs": [], "source": [ "sleep_data_updated.to_csv(UPDATED_SLEEP_DATA_PATH, index=False, index_label=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Helper functions and class" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def training_test_split_by_unique_index(data, index: str, test_size: int = 10):\n", " test_ids = np.random.choice(data[index].unique(), size = test_size, replace=False)\n", " return data[~data[index].isin(test_ids)], data[data[index].isin(test_ids)]" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# Adapted from https://www.tensorflow.org/tutorials/structured_data/time_series\n", "class WindowGenerator():\n", " def __init__(self, data, index: str = \"sleep_id\", input_width: int = INPUT_TIME_STEP, validation_size: int = VALIDATION_SIZE, test_size: int = TEST_SIZE, input_feature_slice: slice = slice(1,-4), label_feature_slice: slice = slice(-4,100), generate_data_now: bool = True):\n", " # Partition data\n", " self.training, self.testing = training_test_split_by_unique_index(data, index, test_size)\n", " self.training, self.validation = training_test_split_by_unique_index(self.training, index, validation_size)\n", "\n", " # Window paramters\n", " self.input_width = input_width\n", " self.label_width = 1\n", " self.shift = 1\n", "\n", " self.total_window_size = self.input_width + self.shift\n", "\n", " self.input_slice = slice(0, input_width)\n", " self.input_indices = np.arange(self.total_window_size)[self.input_slice]\n", "\n", " self.label_start = self.total_window_size - self.label_width\n", " self.labels_slice = slice(self.label_start, None)\n", " self.label_indices = np.arange(self.total_window_size)[self.labels_slice]\n", "\n", " self.input_feature_slice = input_feature_slice\n", " self.label_feature_slice = label_feature_slice\n", "\n", " self.sample_ds = self.make_dataset(data[data[index] == 0])\n", "\n", " if generate_data_now:\n", " self.training_ds = self.make_dataset(self.training, index)\n", " self.validation_ds = self.make_dataset(self.validation, index)\n", " self.testing_ds = self.make_dataset(self.testing, index)\n", "\n", "\n", " def __repr__(self):\n", " return \"WindowGenerator:\\n\\t\" +'\\n\\t'.join([\n", " f'Total window size: {self.total_window_size}',\n", " f'Input indices: {self.input_indices}',\n", " f'Label indices: {self.label_indices}',\n", " ])\n", "\n", " def split_window(self, features):\n", " inputs = features[:, self.input_slice, self.input_feature_slice]\n", " inputs.set_shape([None, self.input_width, None])\n", " \n", " labels = tf.squeeze(features[:, self.labels_slice, self.label_feature_slice])\n", " # labels.set_shape([None, self.label_width, None])\n", " return inputs, labels\n", "\n", " def make_dataset(self, data, index_group: str = \"sleep_id\", sort_by: str = \"minutes_since_begin\"):\n", " ds_all = None\n", " for i_group in data[index_group].unique():\n", " subset_data = np.array(data[data[index_group] == i_group].sort_values(by=[sort_by]), dtype=np.float32)\n", " ds = tf.keras.utils.timeseries_dataset_from_array(\n", " data=subset_data,\n", " targets=None,\n", " sequence_length=self.total_window_size,\n", " sequence_stride=1,\n", " shuffle=False,\n", " batch_size=BATCH_SIZE,)\n", " ds_all = ds if ds_all is None else ds_all.concatenate(ds)\n", " ds_all = ds_all.map(self.split_window)\n", "\n", " return ds_all" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### General Model Helper" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# Adapted from https://www.tensorflow.org/tutorials/structured_data/time_series#linear_model\n", "def compile_and_fit(model, window: WindowGenerator, loss = tf.losses.CategoricalCrossentropy(from_logits=True), optimizer = tf.optimizers.Adam(), metrics = None, early_stop: bool = True, patience:int = 2, baseline = None, epochs: int = MAX_EPOCHS):\n", " if metrics is None:\n", " metrics = [tf.keras.metrics.CategoricalCrossentropy(from_logits=True), tf.keras.metrics.CategoricalAccuracy(), tf.keras.metrics.CategoricalHinge()]\n", "\n", " callbacks = []\n", " if early_stop:\n", " early_stopping = tf.keras.callbacks.EarlyStopping(\n", " monitor='val_loss',\n", " patience=patience,\n", " baseline = baseline,\n", " mode='min'\n", " )\n", " callbacks.append(early_stopping)\n", "\n", " model.compile(\n", " loss=loss,\n", " optimizer=optimizer,\n", " metrics=metrics,\n", " )\n", "\n", " return model.fit(window.training_ds, epochs=epochs, validation_data=window.validation_ds, callbacks=callbacks)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Experimenting" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# USE SUBSET OF DATA FOR EXPERIMENTING\n", "sleep_data_sub = sleep_data[sleep_data.sleep_id < 3]" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "wg_sub = WindowGenerator(sleep_data_sub,validation_size=1, test_size=1)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ele[0][0] = array([[ 0., 7., 40., 1., 0., 0., 0.],\n", " [ 1., 7., 41., 1., 0., 0., 0.],\n", " [ 2., 7., 42., 1., 0., 0., 0.],\n", " [ 3., 7., 43., 1., 0., 0., 0.],\n", " [ 4., 7., 44., 1., 0., 0., 0.],\n", " [ 5., 7., 45., 1., 0., 0., 0.],\n", " [ 6., 7., 46., 1., 0., 0., 0.],\n", " [ 7., 7., 47., 1., 0., 0., 0.],\n", " [ 8., 7., 48., 1., 0., 0., 0.],\n", " [ 9., 7., 49., 1., 0., 0., 0.]], dtype=float32)\n", "ele[0][0] = array([[64., 8., 44., 0., 0., 1., 0.],\n", " [65., 8., 45., 0., 0., 1., 0.],\n", " [66., 8., 46., 0., 0., 1., 0.],\n", " [67., 8., 47., 0., 0., 1., 0.],\n", " [68., 8., 48., 0., 0., 1., 0.],\n", " [69., 8., 49., 0., 0., 1., 0.],\n", " [70., 8., 50., 0., 0., 1., 0.],\n", " [71., 8., 51., 0., 0., 1., 0.],\n", " [72., 8., 52., 0., 0., 1., 0.],\n", " [73., 8., 53., 0., 0., 1., 0.]], dtype=float32)\n", "ele[0][0] = array([[128., 9., 48., 0., 0., 1., 0.],\n", " [129., 9., 49., 0., 0., 1., 0.],\n", " [130., 9., 50., 0., 0., 1., 0.],\n", " [131., 9., 51., 0., 0., 1., 0.],\n", " [132., 9., 52., 0., 0., 0., 1.],\n", " [133., 9., 53., 0., 0., 0., 1.],\n", " [134., 9., 54., 0., 0., 0., 1.],\n", " [135., 9., 55., 0., 0., 0., 1.],\n", " [136., 9., 56., 0., 0., 0., 1.],\n", " [137., 9., 57., 0., 0., 0., 1.]], dtype=float32)\n", "ele[0][0] = array([[192., 10., 52., 0., 1., 0., 0.],\n", " [193., 10., 53., 0., 1., 0., 0.],\n", " [194., 10., 54., 0., 1., 0., 0.],\n", " [195., 10., 55., 0., 1., 0., 0.],\n", " [196., 10., 56., 0., 1., 0., 0.],\n", " [197., 10., 57., 0., 1., 0., 0.],\n", " [198., 10., 58., 0., 1., 0., 0.],\n", " [199., 10., 59., 0., 1., 0., 0.],\n", " [200., 11., 0., 0., 1., 0., 0.],\n", " [201., 11., 1., 0., 1., 0., 0.]], dtype=float32)\n", "ele[0][0] = array([[256., 11., 56., 0., 0., 1., 0.],\n", " [257., 11., 57., 0., 0., 1., 0.],\n", " [258., 11., 58., 0., 0., 1., 0.],\n", " [259., 11., 59., 0., 0., 1., 0.],\n", " [260., 12., 0., 0., 0., 1., 0.],\n", " [261., 12., 1., 0., 0., 1., 0.],\n", " [262., 12., 2., 0., 0., 1., 0.],\n", " [263., 12., 3., 0., 0., 1., 0.],\n", " [264., 12., 4., 0., 0., 1., 0.],\n", " [265., 12., 5., 0., 0., 1., 0.]], dtype=float32)\n", "ele[0][0] = array([[320., 13., 0., 0., 1., 0., 0.],\n", " [321., 13., 1., 0., 1., 0., 0.],\n", " [322., 13., 2., 0., 1., 0., 0.],\n", " [323., 13., 3., 0., 1., 0., 0.],\n", " [324., 13., 4., 0., 1., 0., 0.],\n", " [325., 13., 5., 0., 1., 0., 0.],\n", " [326., 13., 6., 0., 1., 0., 0.],\n", " [327., 13., 7., 0., 1., 0., 0.],\n", " [328., 13., 8., 0., 1., 0., 0.],\n", " [329., 13., 9., 0., 1., 0., 0.]], dtype=float32)\n", "ele[0][0] = array([[384., 14., 4., 0., 0., 0., 1.],\n", " [385., 14., 5., 0., 0., 0., 1.],\n", " [386., 14., 6., 0., 0., 0., 1.],\n", " [387., 14., 7., 0., 0., 0., 1.],\n", " [388., 14., 8., 0., 0., 0., 1.],\n", " [389., 14., 9., 0., 0., 0., 1.],\n", " [390., 14., 10., 0., 0., 0., 1.],\n", " [391., 14., 11., 0., 0., 0., 1.],\n", " [392., 14., 12., 0., 0., 0., 1.],\n", " [393., 14., 13., 0., 0., 0., 1.]], dtype=float32)\n", "ele[0][0] = array([[448., 15., 8., 0., 1., 0., 0.],\n", " [449., 15., 9., 0., 1., 0., 0.],\n", " [450., 15., 10., 0., 1., 0., 0.],\n", " [451., 15., 11., 0., 1., 0., 0.],\n", " [452., 15., 12., 0., 1., 0., 0.],\n", " [453., 15., 13., 0., 1., 0., 0.],\n", " [454., 15., 14., 0., 1., 0., 0.],\n", " [455., 15., 15., 0., 1., 0., 0.],\n", " [456., 15., 16., 0., 1., 0., 0.],\n", " [457., 15., 17., 0., 1., 0., 0.]], dtype=float32)\n" ] } ], "source": [ "for ele in wg_sub.training_ds.as_numpy_iterator():\n", " print(f\"{ele[0][0] = }\")" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_2\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " dense_4 (Dense) (None, 10, 16) 128 \n", " \n", " flatten_2 (Flatten) (None, 160) 0 \n", " \n", " dense_5 (Dense) (None, 4) 644 \n", " \n", "=================================================================\n", "Total params: 772\n", "Trainable params: 772\n", "Non-trainable params: 0\n", "_________________________________________________________________\n", "None\n" ] } ], "source": [ "BASELINE_UNITS = 16\n", "baseline_model = keras.Sequential(\n", " [\n", " layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)),\n", " layers.Dense(BASELINE_UNITS),\n", " # layers.Dense(BASELINE_UNITS),\n", " # layers.Dense(BASELINE_UNITS),\n", " layers.Flatten(),\n", " layers.Dense(SLEEP_STAGES),\n", " ]\n", ")\n", "baseline_model.build()\n", "print(baseline_model.summary())" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/20\n", "8955/8955 [==============================] - 60s 7ms/step - loss: 2.7762 - categorical_crossentropy: 2.7745 - categorical_accuracy: 0.7612 - categorical_hinge: 15.1798 - val_loss: 0.2711 - val_categorical_crossentropy: 0.2710 - val_categorical_accuracy: 0.9546 - val_categorical_hinge: 5.1404\n", "Epoch 2/20\n", "8955/8955 [==============================] - 60s 7ms/step - loss: 0.3110 - categorical_crossentropy: 0.3111 - categorical_accuracy: 0.9418 - categorical_hinge: 1.5468 - val_loss: 0.2106 - val_categorical_crossentropy: 0.2105 - val_categorical_accuracy: 0.9583 - val_categorical_hinge: 0.3215\n", "Epoch 3/20\n", "8955/8955 [==============================] - 60s 7ms/step - loss: 0.2253 - categorical_crossentropy: 0.2254 - categorical_accuracy: 0.9551 - categorical_hinge: 0.2726 - val_loss: 0.2068 - val_categorical_crossentropy: 0.2067 - val_categorical_accuracy: 0.9582 - val_categorical_hinge: 0.3656\n", "Epoch 4/20\n", "8955/8955 [==============================] - 63s 7ms/step - loss: 0.2059 - categorical_crossentropy: 0.2059 - categorical_accuracy: 0.9582 - categorical_hinge: 0.2735 - val_loss: 0.1957 - val_categorical_crossentropy: 0.1957 - val_categorical_accuracy: 0.9591 - val_categorical_hinge: 0.3622\n", "Epoch 5/20\n", "8955/8955 [==============================] - 82s 9ms/step - loss: 0.2016 - categorical_crossentropy: 0.2016 - categorical_accuracy: 0.9587 - categorical_hinge: 0.2638 - val_loss: 0.1970 - val_categorical_crossentropy: 0.1970 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2868\n", "Epoch 6/20\n", "8955/8955 [==============================] - 72s 8ms/step - loss: 0.1995 - categorical_crossentropy: 0.1995 - categorical_accuracy: 0.9585 - categorical_hinge: 0.2506 - val_loss: 0.1973 - val_categorical_crossentropy: 0.1973 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.3453\n", "Epoch 7/20\n", "8955/8955 [==============================] - 75s 8ms/step - loss: 0.1991 - categorical_crossentropy: 0.1992 - categorical_accuracy: 0.9591 - categorical_hinge: 0.2518 - val_loss: 0.1969 - val_categorical_crossentropy: 0.1968 - val_categorical_accuracy: 0.9591 - val_categorical_hinge: 0.2474\n", "Epoch 8/20\n", "8955/8955 [==============================] - 75s 8ms/step - loss: 0.1983 - categorical_crossentropy: 0.1983 - categorical_accuracy: 0.9587 - categorical_hinge: 0.2560 - val_loss: 0.1957 - val_categorical_crossentropy: 0.1956 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2912\n", "Epoch 9/20\n", "8955/8955 [==============================] - 76s 8ms/step - loss: 0.1965 - categorical_crossentropy: 0.1965 - categorical_accuracy: 0.9595 - categorical_hinge: 0.2328 - val_loss: 0.1957 - val_categorical_crossentropy: 0.1957 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.3268\n", "Epoch 10/20\n", "8955/8955 [==============================] - 77s 9ms/step - loss: 0.1961 - categorical_crossentropy: 0.1961 - categorical_accuracy: 0.9598 - categorical_hinge: 0.2375 - val_loss: 0.1954 - val_categorical_crossentropy: 0.1954 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2518\n", "Epoch 11/20\n", "8955/8955 [==============================] - 72s 8ms/step - loss: 0.1980 - categorical_crossentropy: 0.1981 - categorical_accuracy: 0.9584 - categorical_hinge: 0.2387 - val_loss: 0.1948 - val_categorical_crossentropy: 0.1948 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2795\n", "Epoch 12/20\n", "8955/8955 [==============================] - 69s 8ms/step - loss: 0.1959 - categorical_crossentropy: 0.1959 - categorical_accuracy: 0.9594 - categorical_hinge: 0.2465 - val_loss: 0.1950 - val_categorical_crossentropy: 0.1950 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2824\n", "Epoch 13/20\n", "8955/8955 [==============================] - 65s 7ms/step - loss: 0.1960 - categorical_crossentropy: 0.1960 - categorical_accuracy: 0.9596 - categorical_hinge: 0.2413 - val_loss: 0.1955 - val_categorical_crossentropy: 0.1954 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.3859\n", "Epoch 14/20\n", "8955/8955 [==============================] - 68s 8ms/step - loss: 0.1959 - categorical_crossentropy: 0.1959 - categorical_accuracy: 0.9591 - categorical_hinge: 0.2280 - val_loss: 0.1941 - val_categorical_crossentropy: 0.1941 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.3046\n", "Epoch 15/20\n", "8955/8955 [==============================] - 77s 9ms/step - loss: 0.2031 - categorical_crossentropy: 0.2031 - categorical_accuracy: 0.9580 - categorical_hinge: 0.2336 - val_loss: 0.1949 - val_categorical_crossentropy: 0.1948 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.1873\n", "Epoch 16/20\n", "8955/8955 [==============================] - 68s 8ms/step - loss: 0.1966 - categorical_crossentropy: 0.1966 - categorical_accuracy: 0.9589 - categorical_hinge: 0.2361 - val_loss: 0.1938 - val_categorical_crossentropy: 0.1938 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2056\n", "Epoch 17/20\n", "8955/8955 [==============================] - 67s 7ms/step - loss: 0.1924 - categorical_crossentropy: 0.1924 - categorical_accuracy: 0.9599 - categorical_hinge: 0.2510 - val_loss: 0.1934 - val_categorical_crossentropy: 0.1934 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.3111\n", "Epoch 18/20\n", "8955/8955 [==============================] - 67s 7ms/step - loss: 0.2079 - categorical_crossentropy: 0.2079 - categorical_accuracy: 0.9569 - categorical_hinge: 0.3273 - val_loss: 0.1930 - val_categorical_crossentropy: 0.1929 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2956\n", "Epoch 19/20\n", "8955/8955 [==============================] - 67s 7ms/step - loss: 0.1923 - categorical_crossentropy: 0.1923 - categorical_accuracy: 0.9599 - categorical_hinge: 0.3248 - val_loss: 0.1924 - val_categorical_crossentropy: 0.1924 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.6417\n", "Epoch 20/20\n", "8955/8955 [==============================] - 68s 8ms/step - loss: 0.1950 - categorical_crossentropy: 0.1950 - categorical_accuracy: 0.9595 - categorical_hinge: 0.3199 - val_loss: 0.1935 - val_categorical_crossentropy: 0.1935 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.5098\n" ] } ], "source": [ "baseline_history = compile_and_fit(baseline_model, wg_sub)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['loss', 'categorical_crossentropy', 'categorical_accuracy', 'categorical_hinge', 'val_loss', 'val_categorical_crossentropy', 'val_categorical_accuracy', 'val_categorical_hinge'])" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "baseline_history.history.keys()" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "save_history(baseline_history.history, \".history/baseline.csv\")" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "def save_history(history, file_name: str = \"history.csv\"):\n", " pd.DataFrame.from_dict(history).to_csv(file_name, index=False)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-05-11 15:02:03.185843: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: .model/baseline/assets\n" ] } ], "source": [ "baseline_model.save(\".model/baseline\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data Prep\n", "\n", "All inputs follow: (batch_size, timesteps, input_dim)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-05-12 14:18:02.248350: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" ] }, { "data": { "text/plain": [ "WindowGenerator:\n", "\tTotal window size: 6\n", "\tInput indices: [0 1 2 3 4]\n", "\tLabel indices: [5]" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wg = WindowGenerator(sleep_data)\n", "wg" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "4682" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(wg.training_ds)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ele[0].numpy()[0] = array([[ 0. , 8. , 18. , 0.68 , 0.057, 0.2 , 0.063],\n", " [ 1. , 8. , 19. , 0.652, 0.06 , 0.224, 0.064],\n", " [ 2. , 8. , 20. , 0.672, 0.059, 0.209, 0.06 ],\n", " [ 3. , 8. , 21. , 0.645, 0.056, 0.235, 0.064],\n", " [ 4. , 8. , 22. , 0.644, 0.054, 0.244, 0.058]],\n", " dtype=float32)\n", "ele[1].numpy()[0] = array([1., 0., 0., 0.], dtype=float32)\n" ] } ], "source": [ "for ele in wg.sample_ds.take(1):\n", " print(f\"{ele[0].numpy()[0] = }\")\n", " print(f\"{ele[1].numpy()[0] = }\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model 1: LSTM" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "# Hyper-parameters\n", "LSTM_UNITS = 8\n", "LSTM_LEARNING_RATE = 0.0001" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_2\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " lstm (LSTM) (None, 8) 512 \n", " \n", " dense_2 (Dense) (None, 4) 36 \n", " \n", "=================================================================\n", "Total params: 548\n", "Trainable params: 548\n", "Non-trainable params: 0\n", "_________________________________________________________________\n", "None\n" ] } ], "source": [ "# Model Definition\n", "lstm_model = keras.Sequential()\n", "lstm_model.add(layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)))\n", "# lstm_model.add(layers.LSTM(LSTM_UNITS, stateful=False, return_sequences=True))\n", "# lstm_model.add(layers.LSTM(LSTM_UNITS, stateful=False, return_sequences=True))\n", "lstm_model.add(layers.LSTM(LSTM_UNITS, stateful=False, return_sequences=False))\n", "lstm_model.add(layers.Dense(SLEEP_STAGES))\n", "lstm_model.build()\n", "print(lstm_model.summary())" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "# Model Training\n", "lstm_loss = tf.losses.CategoricalCrossentropy(from_logits=True)\n", "lstm_optm = tf.optimizers.Adam(learning_rate=LSTM_LEARNING_RATE)\n", "lstm_metrics = [tf.keras.metrics.CategoricalCrossentropy(from_logits=True), tf.keras.metrics.BinaryCrossentropy(from_logits=True),tf.keras.metrics.Accuracy()]" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9647/9647 [==============================] - 88s 9ms/step - loss: 0.6389 - categorical_crossentropy: 0.6388 - categorical_accuracy: 0.7669 - categorical_hinge: 0.5620 - val_loss: 0.2218 - val_categorical_crossentropy: 0.2217 - val_categorical_accuracy: 0.9551 - val_categorical_hinge: 0.2163\n" ] } ], "source": [ "lstm_history = compile_and_fit(model=lstm_model, window=wg, epochs=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model 2: GRU" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "# Hyper-paramters\n", "GRU_UNITS = 16" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_1\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " gru (GRU) (None, 16) 1200 \n", " \n", " dense_1 (Dense) (None, 4) 68 \n", " \n", "=================================================================\n", "Total params: 1,268\n", "Trainable params: 1,268\n", "Non-trainable params: 0\n", "_________________________________________________________________\n", "None\n" ] } ], "source": [ "gru_model = keras.Sequential()\n", "gru_model.add(layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)))\n", "gru_model.add(layers.GRU(GRU_UNITS))\n", "gru_model.add(layers.Dense(SLEEP_STAGES))\n", "gru_model.build()\n", "print(gru_model.summary())" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "9647/9647 [==============================] - 97s 10ms/step - loss: 0.4700 - categorical_crossentropy: 0.4700 - categorical_accuracy: 0.8416 - categorical_hinge: 0.4278 - val_loss: 0.2088 - val_categorical_crossentropy: 0.2087 - val_categorical_accuracy: 0.9595 - val_categorical_hinge: 0.1954\n" ] } ], "source": [ "gru_history = compile_and_fit(model=gru_model, window=wg, epochs=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model 3: Attention Mechanism" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "ATTENTION_UNITS = 32" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " custom_attention (CustomAtt (None, 32) 497 \n", " ention) \n", " \n", " dense (Dense) (None, 4) 132 \n", " \n", "=================================================================\n", "Total params: 629\n", "Trainable params: 629\n", "Non-trainable params: 0\n", "_________________________________________________________________\n", "None\n" ] } ], "source": [ "am_model = keras.Sequential()\n", "am_model.add(layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)))\n", "am_model.add(CustomAttention(ATTENTION_UNITS))\n", "am_model.add(layers.Dense(SLEEP_STAGES))\n", "am_model.build()\n", "print(am_model.summary())" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x12aaebe50>" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "am_model.load_weights(\"TEST\")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "am_history = compile_and_fit(model=am_model, window=wg, epochs=0)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1442/1442 [==============================] - 14s 4ms/step - loss: 0.1957 - categorical_crossentropy: 0.1957 - categorical_accuracy: 0.9598 - categorical_hinge: 0.1911\n" ] }, { "data": { "text/plain": [ "[0.19572481513023376,\n", " 0.19574123620986938,\n", " 0.9597873687744141,\n", " 0.1910863220691681]" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "am_model.evaluate(wg.testing_ds)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 0.59796184, -0.20803624, -0.2916372 , -0.7170148 ]],\n", " dtype=float32)" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "answer = am_model.predict(np.array(\n", " [[\n", " [ 150., 10., 20., 0.5, 0.1, 0., 0.4],\n", " [ 151., 10., 21., 0.5, 0.2, 0., 0.3],\n", " [ 152., 10., 22., 0.5, 0.3, 0., 0.2],\n", " [ 153., 10., 23., 0.5, 0.4, 0., 0.1],\n", " [ 154., 10., 24., 0.2, 0.1, 0., 0.],\n", " [ 155., 10., 25., 0.2, 0.2, 0., 0.],\n", " [ 156., 10., 26., 0.2, 0.2, 0.1, 0.],\n", " [ 157., 10., 27., 0.2, 0.25, 0.15, 0.],\n", " [ 158., 10., 28., 0.2, 0.3, 0.2, 0.],\n", " [ 159., 10., 29., 0.2, 0.3, 0.3, 0.]\n", " ]]\n", "))\n", "norm = np.linalg.norm(answer)\n", "normalized_answer = answer/norm\n", "normalized_answer" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "pd.DataFrame.from_dict(am_history.history).to_csv(f\"{HISTORY_DIR}/am_220512.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Help on method save_weights in module keras.engine.training:\n", "\n", "save_weights(filepath, overwrite=True, save_format=None, options=None) method of keras.engine.sequential.Sequential instance\n", " Saves all layer weights.\n", " \n", " Either saves in HDF5 or in TensorFlow format based on the `save_format`\n", " argument.\n", " \n", " When saving in HDF5 format, the weight file has:\n", " - `layer_names` (attribute), a list of strings\n", " (ordered names of model layers).\n", " - For every layer, a `group` named `layer.name`\n", " - For every such layer group, a group attribute `weight_names`,\n", " a list of strings\n", " (ordered names of weights tensor of the layer).\n", " - For every weight in the layer, a dataset\n", " storing the weight value, named after the weight tensor.\n", " \n", " When saving in TensorFlow format, all objects referenced by the network are\n", " saved in the same format as `tf.train.Checkpoint`, including any `Layer`\n", " instances or `Optimizer` instances assigned to object attributes. For\n", " networks constructed from inputs and outputs using `tf.keras.Model(inputs,\n", " outputs)`, `Layer` instances used by the network are tracked/saved\n", " automatically. For user-defined classes which inherit from `tf.keras.Model`,\n", " `Layer` instances must be assigned to object attributes, typically in the\n", " constructor. See the documentation of `tf.train.Checkpoint` and\n", " `tf.keras.Model` for details.\n", " \n", " While the formats are the same, do not mix `save_weights` and\n", " `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be\n", " loaded using `Model.load_weights`. Checkpoints saved using\n", " `tf.train.Checkpoint.save` should be restored using the corresponding\n", " `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over\n", " `save_weights` for training checkpoints.\n", " \n", " The TensorFlow format matches objects and variables by starting at a root\n", " object, `self` for `save_weights`, and greedily matching attribute\n", " names. For `Model.save` this is the `Model`, and for `Checkpoint.save` this\n", " is the `Checkpoint` even if the `Checkpoint` has a model attached. This\n", " means saving a `tf.keras.Model` using `save_weights` and loading into a\n", " `tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match\n", " the `Model`'s variables. See the\n", " [guide to training checkpoints](https://www.tensorflow.org/guide/checkpoint)\n", " for details on the TensorFlow format.\n", " \n", " Args:\n", " filepath: String or PathLike, path to the file to save the weights to.\n", " When saving in TensorFlow format, this is the prefix used for\n", " checkpoint files (multiple files are generated). Note that the '.h5'\n", " suffix causes weights to be saved in HDF5 format.\n", " overwrite: Whether to silently overwrite any existing file at the\n", " target location, or provide the user with a manual prompt.\n", " save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or\n", " '.keras' will default to HDF5 if `save_format` is `None`. Otherwise\n", " `None` defaults to 'tf'.\n", " options: Optional `tf.train.CheckpointOptions` object that specifies\n", " options for saving weights.\n", " \n", " Raises:\n", " ImportError: If `h5py` is not available when attempting to save in HDF5\n", " format.\n", "\n" ] } ], "source": [ "help(am_model.save_weights)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-05-12 12:12:02.267722: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.\n", "WARNING:absl:Found untraced functions such as attention_score_vec_layer_call_fn, attention_score_vec_layer_call_and_return_conditional_losses, last_hidden_state_layer_call_fn, last_hidden_state_layer_call_and_return_conditional_losses, attention_score_layer_call_fn while saving (showing 5 of 14). These functions will not be directly callable after loading.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: .model/attention/assets\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:tensorflow:Assets written to: .model/attention/assets\n" ] } ], "source": [ "am_model.save(f\"{MODEL_DIR}/attention\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model Head-to-Head testing" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "HISTORY_DIR = \".history\"\n", "MODEL_DIR = \".model\"" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "base_history = pd.read_csv(f\"{HISTORY_DIR}/baseline.csv\")\n", "lstm_history = pd.read_csv(f\"{HISTORY_DIR}/lstm.csv\")\n", "gru_history = pd.read_csv(f\"{HISTORY_DIR}/gru.csv\")\n", "attn_history = pd.read_csv(f\"{HISTORY_DIR}/attention.csv\")\n", "history_columns = [\"categorical_crossentropy\", \"val_categorical_crossentropy\"] # \"categorical_accuracy\"" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<AxesSubplot:>" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "base_history[history_columns].plot()\n", "lstm_history[history_columns].plot()\n", "gru_history[history_columns].plot()\n", "attn_history[history_columns].plot()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "base_model = tf.keras.models.load_model(f\"{MODEL_DIR}/baseline\")\n", "lstm_model = tf.keras.models.load_model(f\"{MODEL_DIR}/lstm\")\n", "gru_model = tf.keras.models.load_model(f\"{MODEL_DIR}/gru\")" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "ename": "RuntimeError", "evalue": "Unable to restore object of class 'Attention' likely due to name conflict with built-in Keras class '<class 'keras.layers.dense_attention.Attention'>'. To override the built-in Keras definition of the object, decorate your class with `@keras.utils.register_keras_serializable` and include that file in your program, or pass your class in a `keras.utils.CustomObjectScope` that wraps this load call.", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m/Users/nowadmin/Documents/School Folder/CS 437/Lab/Final Project/tf_model.ipynb Cell 65'\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/tf_model.ipynb#ch0000068?line=0'>1</a>\u001b[0m attn_model \u001b[39m=\u001b[39m tf\u001b[39m.\u001b[39;49mkeras\u001b[39m.\u001b[39;49mmodels\u001b[39m.\u001b[39;49mload_model(\u001b[39mf\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39m{\u001b[39;49;00mMODEL_DIR\u001b[39m}\u001b[39;49;00m\u001b[39m/attention\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n", "File \u001b[0;32m~/Documents/School Folder/CS 437/Lab/Final Project/venv/lib/python3.10/site-packages/keras/utils/traceback_utils.py:67\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/utils/traceback_utils.py?line=64'>65</a>\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e: \u001b[39m# pylint: disable=broad-except\u001b[39;00m\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/utils/traceback_utils.py?line=65'>66</a>\u001b[0m filtered_tb \u001b[39m=\u001b[39m _process_traceback_frames(e\u001b[39m.\u001b[39m__traceback__)\n\u001b[0;32m---> <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/utils/traceback_utils.py?line=66'>67</a>\u001b[0m \u001b[39mraise\u001b[39;00m e\u001b[39m.\u001b[39mwith_traceback(filtered_tb) \u001b[39mfrom\u001b[39;00m \u001b[39mNone\u001b[39m\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/utils/traceback_utils.py?line=67'>68</a>\u001b[0m \u001b[39mfinally\u001b[39;00m:\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/utils/traceback_utils.py?line=68'>69</a>\u001b[0m \u001b[39mdel\u001b[39;00m filtered_tb\n", "File \u001b[0;32m~/Documents/School Folder/CS 437/Lab/Final Project/venv/lib/python3.10/site-packages/keras/saving/saved_model/load.py:532\u001b[0m, in \u001b[0;36mKerasObjectLoader._revive_layer_or_model_from_config\u001b[0;34m(self, metadata, node_id)\u001b[0m\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/saving/saved_model/load.py?line=529'>530</a>\u001b[0m builtin_layer \u001b[39m=\u001b[39m layers_module\u001b[39m.\u001b[39mget_builtin_layer(class_name)\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/saving/saved_model/load.py?line=530'>531</a>\u001b[0m \u001b[39mif\u001b[39;00m builtin_layer:\n\u001b[0;32m--> <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/saving/saved_model/load.py?line=531'>532</a>\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/saving/saved_model/load.py?line=532'>533</a>\u001b[0m \u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39mUnable to restore object of class \u001b[39m\u001b[39m\\'\u001b[39;00m\u001b[39m{\u001b[39;00mclass_name\u001b[39m}\u001b[39;00m\u001b[39m\\'\u001b[39;00m\u001b[39m likely due to \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/saving/saved_model/load.py?line=533'>534</a>\u001b[0m \u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39mname conflict with built-in Keras class \u001b[39m\u001b[39m\\'\u001b[39;00m\u001b[39m{\u001b[39;00mbuiltin_layer\u001b[39m}\u001b[39;00m\u001b[39m\\'\u001b[39;00m\u001b[39m. To \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/saving/saved_model/load.py?line=534'>535</a>\u001b[0m \u001b[39m'\u001b[39m\u001b[39moverride the built-in Keras definition of the object, decorate \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/saving/saved_model/load.py?line=535'>536</a>\u001b[0m \u001b[39m'\u001b[39m\u001b[39myour class with `@keras.utils.register_keras_serializable` and \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/saving/saved_model/load.py?line=536'>537</a>\u001b[0m \u001b[39m'\u001b[39m\u001b[39minclude that file in your program, or pass your class in a \u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/saving/saved_model/load.py?line=537'>538</a>\u001b[0m \u001b[39m'\u001b[39m\u001b[39m`keras.utils.CustomObjectScope` that wraps this load call.\u001b[39m\u001b[39m'\u001b[39m) \u001b[39mfrom\u001b[39;00m \u001b[39me\u001b[39;00m\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/saving/saved_model/load.py?line=538'>539</a>\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/keras/saving/saved_model/load.py?line=539'>540</a>\u001b[0m \u001b[39mraise\u001b[39;00m\n", "\u001b[0;31mRuntimeError\u001b[0m: Unable to restore object of class 'Attention' likely due to name conflict with built-in Keras class '<class 'keras.layers.dense_attention.Attention'>'. To override the built-in Keras definition of the object, decorate your class with `@keras.utils.register_keras_serializable` and include that file in your program, or pass your class in a `keras.utils.CustomObjectScope` that wraps this load call." ] } ], "source": [ "attn_model = tf.keras.models.load_model(f\"{MODEL_DIR}/attention\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def model_time():\n", " pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Scratch" ] } ], "metadata": { "kernelspec": { "display_name": "final_lab", "language": "python", "name": "final_lab" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }