{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook for training and testing AI models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import pandas as pd\n",
    "from attention import CustomAttention\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.8.0\n"
     ]
    }
   ],
   "source": [
    "# Show the TensorFlow version this notebook is running against.\n",
    "print(f\"{tf.__version__}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONSTANTS\n",
    "# Every data file lives under one directory, so relocating the data is a one-line change.\n",
    "DATA_DIR = \".data\"\n",
    "RAW_SLEEP_DATA_PATH = f\"{DATA_DIR}/raw_bed_sleep-state.csv\"  # raw export: start time, durations, stages\n",
    "CLEANED_SLEEP_DATA_PATH = f\"{DATA_DIR}/clean_bed_sleep-state.csv\"  # one row per stage segment\n",
    "SLEEP_DATA_PATH = f\"{DATA_DIR}/sleep_data_simple.csv\"  # reduced feature table\n",
    "UPDATED_SLEEP_DATA_PATH = f\"{DATA_DIR}/updated_sleep_data.csv\"  # noisy + original stage probabilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Parameters and hyper-parameters ---\n",
    "SLEEP_STAGES: int = 4  # awake, light, deep, REM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cleaning Raw Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import datetime\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "datetime.datetime(2022, 4, 21, 10, 19, tzinfo=datetime.timezone(datetime.timedelta(seconds=7200)))"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: %z parses the UTC offset of the raw timestamps into an aware datetime.\n",
    "(\n",
    "    datetime.datetime.strptime(\"2022-04-21T10:18:00+02:00\", \"%Y-%m-%dT%H:%M:%S%z\")\n",
    "    + datetime.timedelta(seconds=60)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stage_probability(stage, to_test):\n",
    "    \"\"\"One-hot encode a stage code.\n",
    "\n",
    "    Returns 1.0 when ``to_test`` equals ``stage`` and 0.0 otherwise.\n",
    "    \"\"\"\n",
    "    if stage == to_test:\n",
    "        return 1.0\n",
    "    return 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6.833333333333333"
      ]
     },
     "execution_count": 150,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: hours between the current session's bed time and the most recently\n",
    "# parsed start time. Relies on `start_time` / `in_bed_time` left over from the cleaning\n",
    "# loop below, so that cell must have been run first.\n",
    "# total_seconds() is used because .seconds silently drops whole days.\n",
    "(start_time - in_bed_time).total_seconds() / 3600"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expand each raw record (start timestamp, \"[d0,d1,...]\" segment durations in\n",
    "# seconds, \"[s0,s1,...]\" stage codes) into one output row per segment.\n",
    "# Stage codes map to the probability columns as 0=awake, 1=light, 2=deep, 3=rem.\n",
    "cleaned_info = []\n",
    "date_seen = set()\n",
    "with open(RAW_SLEEP_DATA_PATH, mode='r', newline='') as raw_file:\n",
    "    csv_reader = csv.reader(raw_file)\n",
    "    in_bed_time = None\n",
    "    current_sleep_id = -1\n",
    "    for index, row in enumerate(csv_reader):\n",
    "        if index == 0:\n",
    "            # Replace the raw header with the cleaned schema.\n",
    "            cleaned_info.append([\n",
    "                \"sleep_id\",\n",
    "                \"sleep_begin\",\n",
    "                \"stage_start\",\n",
    "                \"time_since_begin_sec\",\n",
    "                \"stage_duration_sec\",\n",
    "                \"stage_end\",\n",
    "                \"stage_value\",\n",
    "                \"awake_probability\",\n",
    "                \"light_probability\",\n",
    "                \"deep_probability\",\n",
    "                \"rem_probability\",\n",
    "            ])\n",
    "            continue\n",
    "        start_time = datetime.datetime.strptime(row[0], \"%Y-%m-%dT%H:%M:%S%z\")\n",
    "        # Drop duplicate records sharing a start timestamp.\n",
    "        if start_time in date_seen:\n",
    "            continue\n",
    "        date_seen.add(start_time)\n",
    "        # A start earlier than the current session's begin opens a new session\n",
    "        # (the raw file appears to be ordered newest-first).\n",
    "        if in_bed_time is None or in_bed_time > start_time:\n",
    "            current_sleep_id += 1\n",
    "            in_bed_time = start_time\n",
    "        durations = map(int, row[1].strip(\"[]\").split(\",\"))\n",
    "        stages = map(int, row[2].strip(\"[]\").split(\",\"))\n",
    "        # Segments are contiguous: each one starts where the previous one ended, so\n",
    "        # variable-length segments get correct timestamps (offset * previous duration\n",
    "        # was only right when every segment had the same length).\n",
    "        # NOTE(review): zip silently stops at the shorter list if the lengths differ.\n",
    "        current_time = start_time\n",
    "        for duration, stage in zip(durations, stages):\n",
    "            stage_end = current_time + datetime.timedelta(seconds=duration)\n",
    "            cleaned_info.append([\n",
    "                current_sleep_id,\n",
    "                in_bed_time,\n",
    "                current_time,\n",
    "                # total_seconds() also counts whole days, unlike .seconds\n",
    "                int((current_time - in_bed_time).total_seconds()),\n",
    "                duration,\n",
    "                stage_end,\n",
    "                stage,\n",
    "                stage_probability(0, stage),\n",
    "                stage_probability(1, stage),\n",
    "                stage_probability(2, stage),\n",
    "                stage_probability(3, stage),\n",
    "            ])\n",
    "            current_time = stage_end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished Writing Cleaned Data\n"
     ]
    }
   ],
   "source": [
    "# newline='' lets the csv module control line endings (avoids blank rows on Windows).\n",
    "with open(CLEANED_SLEEP_DATA_PATH, 'w', newline='') as clean_file:\n",
    "    csv_writer = csv.writer(clean_file)\n",
    "    csv_writer.writerows(cleaned_info)\n",
    "print(\"Finished Writing Cleaned Data\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a DataFrame from the cleaned data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the cleaned data\n",
    "sleep_df_raw = pd.read_csv(CLEANED_SLEEP_DATA_PATH)#, parse_dates=[\"start\", \"end\"], infer_datetime_format=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess data:\n",
    "#   1. Convert the timestamp columns to timezone-aware UTC datetimes.\n",
    "for timestamp_column in (\"sleep_begin\", \"stage_start\", \"stage_end\"):\n",
    "    sleep_df_raw[timestamp_column] = pd.to_datetime(sleep_df_raw[timestamp_column], utc=True)\n",
    "#   2. Separate time, hour and minute\n",
    "#   MAYBE 3. smaller units: int16 or int8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_minute(row, index):\n",
    "    \"\"\"Return the minute-of-hour of the datetime stored at ``row[index]``.\"\"\"\n",
    "    moment = row[index].time()\n",
    "    return moment.minute\n",
    "\n",
    "\n",
    "def get_hour(row, index):\n",
    "    \"\"\"Return the hour-of-day of the datetime stored at ``row[index]``.\"\"\"\n",
    "    moment = row[index].time()\n",
    "    return moment.hour"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vectorized .dt accessors replace a row-wise apply(axis=1) over every row; the\n",
    "# int64 cast keeps the dtype the row-wise version produced.\n",
    "# NOTE(review): stage_start was converted with utc=True, so these are UTC hours/minutes,\n",
    "# not the sleeper's local wall-clock time.\n",
    "sleep_df_raw[\"stage_start_hour\"] = sleep_df_raw[\"stage_start\"].dt.hour.astype(\"int64\")\n",
    "sleep_df_raw[\"stage_start_minute\"] = sleep_df_raw[\"stage_start\"].dt.minute.astype(\"int64\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 551042 entries, 0 to 551041\n",
      "Data columns (total 13 columns):\n",
      " #   Column                Non-Null Count   Dtype              \n",
      "---  ------                --------------   -----              \n",
      " 0   sleep_id              551042 non-null  int64              \n",
      " 1   sleep_begin           551042 non-null  datetime64[ns, UTC]\n",
      " 2   stage_start           551042 non-null  datetime64[ns, UTC]\n",
      " 3   time_since_begin_sec  551042 non-null  int64              \n",
      " 4   stage_duration_sec    551042 non-null  int64              \n",
      " 5   stage_end             551042 non-null  datetime64[ns, UTC]\n",
      " 6   stage_value           551042 non-null  int64              \n",
      " 7   awake_probability     551042 non-null  float64            \n",
      " 8   light_probability     551042 non-null  float64            \n",
      " 9   deep_probability      551042 non-null  float64            \n",
      " 10  rem_probability       551042 non-null  float64            \n",
      " 11  stage_start_hour      551042 non-null  int64              \n",
      " 12  stage_start_minute    551042 non-null  int64              \n",
      "dtypes: datetime64[ns, UTC](3), float64(4), int64(6)\n",
      "memory usage: 54.7 MB\n"
     ]
    }
   ],
   "source": [
    "# Inspect column dtypes, null counts and memory footprint of the cleaned frame.\n",
    "sleep_df_raw.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sleep_id</th>\n",
       "      <th>sleep_begin</th>\n",
       "      <th>stage_start</th>\n",
       "      <th>time_since_begin_sec</th>\n",
       "      <th>stage_duration_sec</th>\n",
       "      <th>stage_end</th>\n",
       "      <th>stage_value</th>\n",
       "      <th>awake_probability</th>\n",
       "      <th>light_probability</th>\n",
       "      <th>deep_probability</th>\n",
       "      <th>rem_probability</th>\n",
       "      <th>stage_start_hour</th>\n",
       "      <th>stage_start_minute</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:19:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:19:00+00:00</td>\n",
       "      <td>60</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:20:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:20:00+00:00</td>\n",
       "      <td>120</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:21:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:21:00+00:00</td>\n",
       "      <td>180</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:22:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:22:00+00:00</td>\n",
       "      <td>240</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:23:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551037</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:17:00+00:00</td>\n",
       "      <td>25560</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:18:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551038</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:18:00+00:00</td>\n",
       "      <td>25620</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:19:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551039</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:19:00+00:00</td>\n",
       "      <td>25680</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:20:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551040</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:20:00+00:00</td>\n",
       "      <td>25740</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:21:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551041</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:21:00+00:00</td>\n",
       "      <td>25800</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:22:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>551042 rows × 13 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        sleep_id               sleep_begin               stage_start  \\\n",
       "0              0 2022-04-21 08:18:00+00:00 2022-04-21 08:18:00+00:00   \n",
       "1              0 2022-04-21 08:18:00+00:00 2022-04-21 08:19:00+00:00   \n",
       "2              0 2022-04-21 08:18:00+00:00 2022-04-21 08:20:00+00:00   \n",
       "3              0 2022-04-21 08:18:00+00:00 2022-04-21 08:21:00+00:00   \n",
       "4              0 2022-04-21 08:18:00+00:00 2022-04-21 08:22:00+00:00   \n",
       "...          ...                       ...                       ...   \n",
       "551037      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:17:00+00:00   \n",
       "551038      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:18:00+00:00   \n",
       "551039      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:19:00+00:00   \n",
       "551040      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:20:00+00:00   \n",
       "551041      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:21:00+00:00   \n",
       "\n",
       "        time_since_begin_sec  stage_duration_sec                 stage_end  \\\n",
       "0                          0                  60 2022-04-21 08:19:00+00:00   \n",
       "1                         60                  60 2022-04-21 08:20:00+00:00   \n",
       "2                        120                  60 2022-04-21 08:21:00+00:00   \n",
       "3                        180                  60 2022-04-21 08:22:00+00:00   \n",
       "4                        240                  60 2022-04-21 08:23:00+00:00   \n",
       "...                      ...                 ...                       ...   \n",
       "551037                 25560                  60 2019-02-11 13:18:00+00:00   \n",
       "551038                 25620                  60 2019-02-11 13:19:00+00:00   \n",
       "551039                 25680                  60 2019-02-11 13:20:00+00:00   \n",
       "551040                 25740                  60 2019-02-11 13:21:00+00:00   \n",
       "551041                 25800                  60 2019-02-11 13:22:00+00:00   \n",
       "\n",
       "        stage_value  awake_probability  light_probability  deep_probability  \\\n",
       "0                 0                1.0                0.0               0.0   \n",
       "1                 0                1.0                0.0               0.0   \n",
       "2                 0                1.0                0.0               0.0   \n",
       "3                 0                1.0                0.0               0.0   \n",
       "4                 0                1.0                0.0               0.0   \n",
       "...             ...                ...                ...               ...   \n",
       "551037            1                0.0                1.0               0.0   \n",
       "551038            1                0.0                1.0               0.0   \n",
       "551039            1                0.0                1.0               0.0   \n",
       "551040            1                0.0                1.0               0.0   \n",
       "551041            1                0.0                1.0               0.0   \n",
       "\n",
       "        rem_probability  stage_start_hour  stage_start_minute  \n",
       "0                   0.0                 8                  18  \n",
       "1                   0.0                 8                  19  \n",
       "2                   0.0                 8                  20  \n",
       "3                   0.0                 8                  21  \n",
       "4                   0.0                 8                  22  \n",
       "...                 ...               ...                 ...  \n",
       "551037              0.0                13                  17  \n",
       "551038              0.0                13                  18  \n",
       "551039              0.0                13                  19  \n",
       "551040              0.0                13                  20  \n",
       "551041              0.0                13                  21  \n",
       "\n",
       "[551042 rows x 13 columns]"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the cleaned frame (pandas truncates the display to its head and tail).\n",
    "sleep_df_raw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the model-relevant columns. .copy() makes sleep_data an independent frame, so the\n",
    "# insert below does not operate on a slice of sleep_df_raw (no SettingWithCopy ambiguity).\n",
    "sleep_data = sleep_df_raw[[\n",
    "    \"sleep_id\",\n",
    "    \"stage_start_hour\",\n",
    "    \"stage_start_minute\",\n",
    "    \"awake_probability\",\n",
    "    \"rem_probability\",\n",
    "    \"light_probability\",\n",
    "    \"deep_probability\",\n",
    "]].copy()\n",
    "# Whole minutes elapsed since the session began, placed right after sleep_id.\n",
    "sleep_data.insert(loc=1, column=\"minutes_since_begin\", value=sleep_df_raw[\"time_since_begin_sec\"] // 60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   sleep_id  minutes_since_begin  stage_start_hour  stage_start_minute  \\\n",
      "0         0                    0                 8                  18   \n",
      "1         0                    1                 8                  19   \n",
      "2         0                    2                 8                  20   \n",
      "3         0                    3                 8                  21   \n",
      "4         0                    4                 8                  22   \n",
      "\n",
      "   awake_probability  rem_probability  light_probability  deep_probability  \n",
      "0                1.0              0.0                0.0               0.0  \n",
      "1                1.0              0.0                0.0               0.0  \n",
      "2                1.0              0.0                0.0               0.0  \n",
      "3                1.0              0.0                0.0               0.0  \n",
      "4                1.0              0.0                0.0               0.0  \n",
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 551042 entries, 0 to 551041\n",
      "Data columns (total 8 columns):\n",
      " #   Column               Non-Null Count   Dtype  \n",
      "---  ------               --------------   -----  \n",
      " 0   sleep_id             551042 non-null  int64  \n",
      " 1   minutes_since_begin  551042 non-null  int64  \n",
      " 2   stage_start_hour     551042 non-null  int64  \n",
      " 3   stage_start_minute   551042 non-null  int64  \n",
      " 4   awake_probability    551042 non-null  float64\n",
      " 5   rem_probability      551042 non-null  float64\n",
      " 6   light_probability    551042 non-null  float64\n",
      " 7   deep_probability     551042 non-null  float64\n",
      "dtypes: float64(4), int64(4)\n",
      "memory usage: 33.6 MB\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "# info() prints its own report and returns None, so it must not be wrapped in print()\n",
    "# (that produced a stray \"None\" line).\n",
    "display(sleep_data.head())\n",
    "sleep_data.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write to the configured path so this writer and any reader of SLEEP_DATA_PATH cannot drift apart.\n",
    "# (index_label is ignored when index=False, so it is not passed.)\n",
    "sleep_data.to_csv(SLEEP_DATA_PATH, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Development"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset split sizes (unit presumably sleep sessions -- TODO confirm)\n",
    "VALIDATION_SIZE = 365\n",
    "TEST_SIZE = 365 // 2\n",
    "\n",
    "# Training hyper-parameters\n",
    "BATCH_SIZE = 64\n",
    "INPUT_TIME_STEP = 5  # in minutes\n",
    "INPUT_FEATURES_SIZE = 7\n",
    "MAX_EPOCHS = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "SAMPLE_COUNT = 1000\n",
    "# Stage confusion matrix, given in percent. Rows and columns use the order\n",
    "# A(wake), R(EM), L(ight), D(eep) -- the same order as the probability columns of sleep_data.\n",
    "#   A R L D\n",
    "# A\n",
    "# R\n",
    "# L\n",
    "# D\n",
    "CONFUSION_MATRIX = np.array(\n",
    "    [\n",
    "        [66.2, 5.0, 22.5, 6.2],\n",
    "        [1.6, 60.7, 33.0, 4.7],\n",
    "        [3.8, 22.3, 55.4, 18.5],\n",
    "        [0.0, 1.3, 26.7, 72.0],\n",
    "    ]\n",
    ")\n",
    "# Row A only sums to 99.9% (presumably rounding in the source figures). Dividing each\n",
    "# row by its own sum, instead of by 100, turns every row into a proper probability\n",
    "# distribution (sums to 1 up to float rounding).\n",
    "CONFUSION_MATRIX = CONFUSION_MATRIX / CONFUSION_MATRIX.sum(axis=1, keepdims=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset carrying both noisy and original stage probabilities.\n",
    "# (SLEEP_DATA_PATH holds the earlier version without the noisy columns.)\n",
    "sleep_data = pd.read_csv(filepath_or_buffer=UPDATED_SLEEP_DATA_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sleep_id</th>\n",
       "      <th>minutes_since_begin</th>\n",
       "      <th>stage_start_hour</th>\n",
       "      <th>stage_start_minute</th>\n",
       "      <th>awake_probability_noisy</th>\n",
       "      <th>rem_probability_noisy</th>\n",
       "      <th>light_probability_noisy</th>\n",
       "      <th>deep_probability_noisy</th>\n",
       "      <th>awake_probability_original</th>\n",
       "      <th>rem_probability_original</th>\n",
       "      <th>light_probability_original</th>\n",
       "      <th>deep_probability_original</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>0.680</td>\n",
       "      <td>0.057</td>\n",
       "      <td>0.200</td>\n",
       "      <td>0.063</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>0.652</td>\n",
       "      <td>0.060</td>\n",
       "      <td>0.224</td>\n",
       "      <td>0.064</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>0.672</td>\n",
       "      <td>0.059</td>\n",
       "      <td>0.209</td>\n",
       "      <td>0.060</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>0.645</td>\n",
       "      <td>0.056</td>\n",
       "      <td>0.235</td>\n",
       "      <td>0.064</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>0.644</td>\n",
       "      <td>0.054</td>\n",
       "      <td>0.244</td>\n",
       "      <td>0.058</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551037</th>\n",
       "      <td>1132.0</td>\n",
       "      <td>426.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>0.041</td>\n",
       "      <td>0.193</td>\n",
       "      <td>0.576</td>\n",
       "      <td>0.190</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551038</th>\n",
       "      <td>1132.0</td>\n",
       "      <td>427.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>0.027</td>\n",
       "      <td>0.209</td>\n",
       "      <td>0.563</td>\n",
       "      <td>0.201</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551039</th>\n",
       "      <td>1132.0</td>\n",
       "      <td>428.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>0.032</td>\n",
       "      <td>0.220</td>\n",
       "      <td>0.574</td>\n",
       "      <td>0.174</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551040</th>\n",
       "      <td>1132.0</td>\n",
       "      <td>429.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>0.036</td>\n",
       "      <td>0.256</td>\n",
       "      <td>0.530</td>\n",
       "      <td>0.178</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551041</th>\n",
       "      <td>1132.0</td>\n",
       "      <td>430.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>0.033</td>\n",
       "      <td>0.205</td>\n",
       "      <td>0.571</td>\n",
       "      <td>0.191</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>551042 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        sleep_id  minutes_since_begin  stage_start_hour  stage_start_minute  \\\n",
       "0            0.0                  0.0               8.0                18.0   \n",
       "1            0.0                  1.0               8.0                19.0   \n",
       "2            0.0                  2.0               8.0                20.0   \n",
       "3            0.0                  3.0               8.0                21.0   \n",
       "4            0.0                  4.0               8.0                22.0   \n",
       "...          ...                  ...               ...                 ...   \n",
       "551037    1132.0                426.0              13.0                17.0   \n",
       "551038    1132.0                427.0              13.0                18.0   \n",
       "551039    1132.0                428.0              13.0                19.0   \n",
       "551040    1132.0                429.0              13.0                20.0   \n",
       "551041    1132.0                430.0              13.0                21.0   \n",
       "\n",
       "        awake_probability_noisy  rem_probability_noisy  \\\n",
       "0                         0.680                  0.057   \n",
       "1                         0.652                  0.060   \n",
       "2                         0.672                  0.059   \n",
       "3                         0.645                  0.056   \n",
       "4                         0.644                  0.054   \n",
       "...                         ...                    ...   \n",
       "551037                    0.041                  0.193   \n",
       "551038                    0.027                  0.209   \n",
       "551039                    0.032                  0.220   \n",
       "551040                    0.036                  0.256   \n",
       "551041                    0.033                  0.205   \n",
       "\n",
       "        light_probability_noisy  deep_probability_noisy  \\\n",
       "0                         0.200                   0.063   \n",
       "1                         0.224                   0.064   \n",
       "2                         0.209                   0.060   \n",
       "3                         0.235                   0.064   \n",
       "4                         0.244                   0.058   \n",
       "...                         ...                     ...   \n",
       "551037                    0.576                   0.190   \n",
       "551038                    0.563                   0.201   \n",
       "551039                    0.574                   0.174   \n",
       "551040                    0.530                   0.178   \n",
       "551041                    0.571                   0.191   \n",
       "\n",
       "        awake_probability_original  rem_probability_original  \\\n",
       "0                              1.0                       0.0   \n",
       "1                              1.0                       0.0   \n",
       "2                              1.0                       0.0   \n",
       "3                              1.0                       0.0   \n",
       "4                              1.0                       0.0   \n",
       "...                            ...                       ...   \n",
       "551037                         0.0                       0.0   \n",
       "551038                         0.0                       0.0   \n",
       "551039                         0.0                       0.0   \n",
       "551040                         0.0                       0.0   \n",
       "551041                         0.0                       0.0   \n",
       "\n",
       "        light_probability_original  deep_probability_original  \n",
       "0                              0.0                        0.0  \n",
       "1                              0.0                        0.0  \n",
       "2                              0.0                        0.0  \n",
       "3                              0.0                        0.0  \n",
       "4                              0.0                        0.0  \n",
       "...                            ...                        ...  \n",
       "551037                         1.0                        0.0  \n",
       "551038                         1.0                        0.0  \n",
       "551039                         1.0                        0.0  \n",
       "551040                         1.0                        0.0  \n",
       "551041                         1.0                        0.0  \n",
       "\n",
       "[551042 rows x 12 columns]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sleep_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Randomizations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(arr, axis: int = 0):\n",
    "    \"\"\"Softmax of ``arr`` along ``axis``.\n",
    "\n",
    "    ``axis`` defaults to 0, i.e. normalisation over the first axis (for a 1-D\n",
    "    input that is the whole vector). The per-slice maximum is subtracted\n",
    "    before exponentiating so large inputs do not overflow to ``inf`` and turn\n",
    "    into ``nan``; the result is mathematically unchanged.\n",
    "    \"\"\"\n",
    "    values = np.asarray(arr, dtype=float)\n",
    "    exps = np.exp(values - np.max(values, axis=axis, keepdims=True))\n",
    "    return exps / np.sum(exps, axis=axis, keepdims=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 133,
   "metadata": {},
   "outputs": [],
   "source": [
    "def noisy_randomizer(confusion_matrix_index: int, sample_count: int = SAMPLE_COUNT, confusion_matrix = CONFUSION_MATRIX):\n",
    "    \"\"\"Return empirical class frequencies from one multinomial draw.\n",
    "\n",
    "    Row ``confusion_matrix_index`` of ``confusion_matrix`` supplies the event\n",
    "    probabilities for ``sample_count`` trials; the per-class counts are then\n",
    "    divided by ``sample_count``.\n",
    "    \"\"\"\n",
    "    class_probabilities = confusion_matrix[confusion_matrix_index]\n",
    "    draw = np.random.multinomial(sample_count, class_probabilities, size=1)\n",
    "    counts = draw[0]\n",
    "    return counts / sample_count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 173,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_updated(dataframe, n_stages: int = 4):\n",
    "    \"\"\"Return a new frame with sampled noisy copies of the last ``n_stages`` columns.\n",
    "\n",
    "    The last ``n_stages`` columns of ``dataframe`` are the per-stage columns.\n",
    "    For every row, their argmax is passed to ``noisy_randomizer`` and the\n",
    "    returned vector becomes the ``<name>_noisy`` columns; the stage columns are\n",
    "    kept unchanged and renamed ``<name>_original``.\n",
    "\n",
    "    Column order of the result: leading columns, ``*_noisy``, ``*_original``.\n",
    "    ``n_stages`` must equal the length of the vectors ``noisy_randomizer``\n",
    "    returns, otherwise the column count will not match.\n",
    "    \"\"\"\n",
    "    values = dataframe.to_numpy()\n",
    "    leading_values, stage_values = values[:, :-n_stages], values[:, -n_stages:]\n",
    "    # One noisy probability vector per row, chosen by the argmax of its stage columns.\n",
    "    noisy_probs = np.array([noisy_randomizer(stage) for stage in np.argmax(stage_values, axis=1)])\n",
    "\n",
    "    column_names = dataframe.columns.to_list()\n",
    "    leading_names, stage_names = column_names[:-n_stages], column_names[-n_stages:]\n",
    "    noisy_names = [f\"{name}_noisy\" for name in stage_names]\n",
    "    original_names = [f\"{name}_original\" for name in stage_names]\n",
    "\n",
    "    return pd.DataFrame(\n",
    "        np.concatenate([leading_values, noisy_probs, stage_values], axis=1),\n",
    "        columns=leading_names + noisy_names + original_names,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add sampled `*_noisy` stage columns; the source stage columns are kept as `*_original`.\n",
    "sleep_data_updated = create_updated(sleep_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the noisy dataset; index=False already omits the row index, so no index_label is needed.\n",
    "sleep_data_updated.to_csv(UPDATED_SLEEP_DATA_PATH, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper functions and class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training_test_split_by_unique_index(data, index: str, test_size: int = 10, seed: int = None):\n",
    "    \"\"\"Split ``data`` into ``(train, test)`` by whole groups of the ``index`` column.\n",
    "\n",
    "    ``test_size`` distinct values of ``data[index]`` are drawn at random without\n",
    "    replacement; every row carrying one of them goes to ``test`` and every other\n",
    "    row to ``train``, so no group is split between the two.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : pandas.DataFrame\n",
    "    index : str\n",
    "        Grouping column, e.g. ``\"sleep_id\"``.\n",
    "    test_size : int\n",
    "        Number of distinct groups placed in ``test``.\n",
    "    seed : int, optional\n",
    "        If given, the draw uses ``np.random.default_rng(seed)`` and is\n",
    "        reproducible; if ``None`` (default) the global NumPy RNG is used.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of pandas.DataFrame\n",
    "        ``(train, test)``, each a row subset of ``data`` in its original order.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``test_size`` exceeds the number of distinct groups.\n",
    "    \"\"\"\n",
    "    group_ids = data[index].unique()\n",
    "    if test_size > len(group_ids):\n",
    "        raise ValueError(f\"test_size={test_size} exceeds the {len(group_ids)} distinct values of {index!r}\")\n",
    "    rng = np.random if seed is None else np.random.default_rng(seed)\n",
    "    test_ids = rng.choice(group_ids, size=test_size, replace=False)\n",
    "    # Compute the membership mask once and reuse it for both halves.\n",
    "    in_test = data[index].isin(test_ids)\n",
    "    return data[~in_test], data[in_test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adapted from https://www.tensorflow.org/tutorials/structured_data/time_series\n",
    "class WindowGenerator():\n",
    "    \"\"\"Build windowed tf.data datasets for next-step prediction.\n",
    "\n",
    "    ``data`` is partitioned by whole ``index`` groups into training / validation /\n",
    "    testing frames. Each window holds ``input_width`` consecutive rows as inputs\n",
    "    and the single row right after them as the label. Windows are generated per\n",
    "    group, so no window mixes rows from two groups.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data, index: str = \"sleep_id\", input_width: int = INPUT_TIME_STEP, validation_size: int = VALIDATION_SIZE, test_size: int = TEST_SIZE, input_feature_slice: slice = slice(1,-4), label_feature_slice: slice = slice(-4,None), generate_data_now: bool = True):\n",
    "        \"\"\"\n",
    "        Parameters\n",
    "        ----------\n",
    "        data : pandas.DataFrame\n",
    "            Must contain the ``index`` column and ``minutes_since_begin`` (used to\n",
    "            order rows inside each group).\n",
    "        index : str\n",
    "            Grouping column; groups are never split across partitions or windows.\n",
    "        input_width : int\n",
    "            Number of time steps per input window.\n",
    "        validation_size, test_size : int\n",
    "            Number of distinct groups held out for validation / testing.\n",
    "        input_feature_slice, label_feature_slice : slice\n",
    "            Column slices over the numeric array of ``data`` (all columns,\n",
    "            including ``index``) selecting model inputs and labels. The label\n",
    "            default ``slice(-4, None)`` takes the last four columns.\n",
    "        generate_data_now : bool\n",
    "            If True, build ``training_ds``, ``validation_ds`` and ``testing_ds`` now.\n",
    "        \"\"\"\n",
    "        # Partition data by whole groups: test first, then validation out of the remainder.\n",
    "        self.training, self.testing = training_test_split_by_unique_index(data, index, test_size)\n",
    "        self.training, self.validation = training_test_split_by_unique_index(self.training, index, validation_size)\n",
    "\n",
    "        # Window parameters: predict the single step that follows the inputs.\n",
    "        self.input_width = input_width\n",
    "        self.label_width = 1\n",
    "        self.shift = 1\n",
    "\n",
    "        self.total_window_size = self.input_width + self.shift\n",
    "\n",
    "        self.input_slice = slice(0, input_width)\n",
    "        self.input_indices = np.arange(self.total_window_size)[self.input_slice]\n",
    "\n",
    "        self.label_start = self.total_window_size - self.label_width\n",
    "        self.labels_slice = slice(self.label_start, None)\n",
    "        self.label_indices = np.arange(self.total_window_size)[self.labels_slice]\n",
    "\n",
    "        self.input_feature_slice = input_feature_slice\n",
    "        self.label_feature_slice = label_feature_slice\n",
    "\n",
    "        # Small dataset built from group 0 only, for eyeballing windows.\n",
    "        self.sample_ds = self.make_dataset(data[data[index] == 0], index)\n",
    "\n",
    "        if generate_data_now:\n",
    "            self.training_ds = self.make_dataset(self.training, index)\n",
    "            self.validation_ds = self.make_dataset(self.validation, index)\n",
    "            self.testing_ds = self.make_dataset(self.testing, index)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return \"WindowGenerator:\\n\\t\" +'\\n\\t'.join([\n",
    "            f'Total window size: {self.total_window_size}',\n",
    "            f'Input indices: {self.input_indices}',\n",
    "            f'Label indices: {self.label_indices}',\n",
    "        ])\n",
    "\n",
    "    def split_window(self, features):\n",
    "        \"\"\"Split a batch of windows ``(batch, total_window_size, n_columns)`` into ``(inputs, labels)``.\n",
    "\n",
    "        Returns inputs of shape ``(batch, input_width, n_input_features)`` and\n",
    "        labels of shape ``(batch, n_label_features)``.\n",
    "        \"\"\"\n",
    "        inputs = features[:, self.input_slice, self.input_feature_slice]\n",
    "        inputs.set_shape([None, self.input_width, None])\n",
    "\n",
    "        # Squeeze only the length-1 time axis. A bare tf.squeeze would also drop the\n",
    "        # batch axis when a batch holds a single window (e.g. the last, partial batch\n",
    "        # of a group), giving labels of the wrong rank.\n",
    "        labels = tf.squeeze(features[:, self.labels_slice, self.label_feature_slice], axis=1)\n",
    "        return inputs, labels\n",
    "\n",
    "    def make_dataset(self, data, index_group: str = \"sleep_id\", sort_by: str = \"minutes_since_begin\"):\n",
    "        \"\"\"Build a batched ``(inputs, labels)`` dataset from ``data``.\n",
    "\n",
    "        Windows are generated separately for each ``index_group`` value (rows\n",
    "        sorted by ``sort_by``) and the per-group datasets are concatenated.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError\n",
    "            If ``data`` contains no groups, since there is nothing to window.\n",
    "        \"\"\"\n",
    "        group_datasets = []\n",
    "        for group_id in data[index_group].unique():\n",
    "            group_rows = np.array(data[data[index_group] == group_id].sort_values(by=[sort_by]), dtype=np.float32)\n",
    "            group_datasets.append(tf.keras.utils.timeseries_dataset_from_array(\n",
    "                data=group_rows,\n",
    "                targets=None,\n",
    "                sequence_length=self.total_window_size,\n",
    "                sequence_stride=1,\n",
    "                shuffle=False,\n",
    "                batch_size=BATCH_SIZE,\n",
    "            ))\n",
    "        if not group_datasets:\n",
    "            raise ValueError(f\"make_dataset received no rows to window (grouped by {index_group!r})\")\n",
    "\n",
    "        ds_all = group_datasets[0]\n",
    "        for ds in group_datasets[1:]:\n",
    "            ds_all = ds_all.concatenate(ds)\n",
    "        return ds_all.map(self.split_window)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### General Model Helper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adapted from https://www.tensorflow.org/tutorials/structured_data/time_series#linear_model\n",
    "def compile_and_fit(model, window: WindowGenerator, loss = None, optimizer = None, metrics = None, early_stop: bool = True, patience:int = 2, baseline = None, epochs: int = MAX_EPOCHS):\n",
    "    \"\"\"Compile ``model`` and train it on ``window.training_ds``, validating on ``window.validation_ds``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : keras.Model\n",
    "    window : WindowGenerator\n",
    "        Provides ``training_ds`` and ``validation_ds``.\n",
    "    loss : optional\n",
    "        Defaults to a new ``CategoricalCrossentropy(from_logits=True)``.\n",
    "    optimizer : optional\n",
    "        Defaults to a new ``Adam()`` with default settings.\n",
    "    metrics : list, optional\n",
    "        Defaults to categorical cross-entropy (from logits), categorical accuracy\n",
    "        and categorical hinge.\n",
    "    early_stop : bool\n",
    "        If True, add an ``EarlyStopping`` callback monitoring ``val_loss``.\n",
    "    patience : int\n",
    "        Epochs without ``val_loss`` improvement before stopping.\n",
    "    baseline : float, optional\n",
    "        Forwarded to ``EarlyStopping``.\n",
    "    epochs : int\n",
    "        Maximum number of training epochs.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The ``History`` object returned by ``model.fit``.\n",
    "    \"\"\"\n",
    "    # Build defaults per call: default-argument objects would be evaluated once,\n",
    "    # at definition time, and shared by every call - i.e. one Adam instance\n",
    "    # (slot variables and iteration counter included) for every model.\n",
    "    if loss is None:\n",
    "        loss = tf.losses.CategoricalCrossentropy(from_logits=True)\n",
    "    if optimizer is None:\n",
    "        optimizer = tf.optimizers.Adam()\n",
    "    if metrics is None:\n",
    "        metrics = [tf.keras.metrics.CategoricalCrossentropy(from_logits=True), tf.keras.metrics.CategoricalAccuracy(), tf.keras.metrics.CategoricalHinge()]\n",
    "\n",
    "    callbacks = []\n",
    "    if early_stop:\n",
    "        callbacks.append(tf.keras.callbacks.EarlyStopping(\n",
    "            monitor='val_loss',\n",
    "            patience=patience,\n",
    "            baseline=baseline,\n",
    "            mode='min',\n",
    "        ))\n",
    "\n",
    "    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)\n",
    "\n",
    "    return model.fit(window.training_ds, epochs=epochs, validation_data=window.validation_ds, callbacks=callbacks)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Experimenting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# USE SUBSET OF DATA FOR EXPERIMENTING\n",
    "# Keep only rows whose sleep_id is below 3 so experiments run quickly.\n",
    "sleep_data_sub = sleep_data[sleep_data.sleep_id < 3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out one sleep_id group each for validation and testing; the remaining group(s) are training data.\n",
    "wg_sub = WindowGenerator(sleep_data_sub,validation_size=1, test_size=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ele[0][0] = array([[ 0.,  7., 40.,  1.,  0.,  0.,  0.],\n",
      "       [ 1.,  7., 41.,  1.,  0.,  0.,  0.],\n",
      "       [ 2.,  7., 42.,  1.,  0.,  0.,  0.],\n",
      "       [ 3.,  7., 43.,  1.,  0.,  0.,  0.],\n",
      "       [ 4.,  7., 44.,  1.,  0.,  0.,  0.],\n",
      "       [ 5.,  7., 45.,  1.,  0.,  0.,  0.],\n",
      "       [ 6.,  7., 46.,  1.,  0.,  0.,  0.],\n",
      "       [ 7.,  7., 47.,  1.,  0.,  0.,  0.],\n",
      "       [ 8.,  7., 48.,  1.,  0.,  0.,  0.],\n",
      "       [ 9.,  7., 49.,  1.,  0.,  0.,  0.]], dtype=float32)\n",
      "ele[0][0] = array([[64.,  8., 44.,  0.,  0.,  1.,  0.],\n",
      "       [65.,  8., 45.,  0.,  0.,  1.,  0.],\n",
      "       [66.,  8., 46.,  0.,  0.,  1.,  0.],\n",
      "       [67.,  8., 47.,  0.,  0.,  1.,  0.],\n",
      "       [68.,  8., 48.,  0.,  0.,  1.,  0.],\n",
      "       [69.,  8., 49.,  0.,  0.,  1.,  0.],\n",
      "       [70.,  8., 50.,  0.,  0.,  1.,  0.],\n",
      "       [71.,  8., 51.,  0.,  0.,  1.,  0.],\n",
      "       [72.,  8., 52.,  0.,  0.,  1.,  0.],\n",
      "       [73.,  8., 53.,  0.,  0.,  1.,  0.]], dtype=float32)\n",
      "ele[0][0] = array([[128.,   9.,  48.,   0.,   0.,   1.,   0.],\n",
      "       [129.,   9.,  49.,   0.,   0.,   1.,   0.],\n",
      "       [130.,   9.,  50.,   0.,   0.,   1.,   0.],\n",
      "       [131.,   9.,  51.,   0.,   0.,   1.,   0.],\n",
      "       [132.,   9.,  52.,   0.,   0.,   0.,   1.],\n",
      "       [133.,   9.,  53.,   0.,   0.,   0.,   1.],\n",
      "       [134.,   9.,  54.,   0.,   0.,   0.,   1.],\n",
      "       [135.,   9.,  55.,   0.,   0.,   0.,   1.],\n",
      "       [136.,   9.,  56.,   0.,   0.,   0.,   1.],\n",
      "       [137.,   9.,  57.,   0.,   0.,   0.,   1.]], dtype=float32)\n",
      "ele[0][0] = array([[192.,  10.,  52.,   0.,   1.,   0.,   0.],\n",
      "       [193.,  10.,  53.,   0.,   1.,   0.,   0.],\n",
      "       [194.,  10.,  54.,   0.,   1.,   0.,   0.],\n",
      "       [195.,  10.,  55.,   0.,   1.,   0.,   0.],\n",
      "       [196.,  10.,  56.,   0.,   1.,   0.,   0.],\n",
      "       [197.,  10.,  57.,   0.,   1.,   0.,   0.],\n",
      "       [198.,  10.,  58.,   0.,   1.,   0.,   0.],\n",
      "       [199.,  10.,  59.,   0.,   1.,   0.,   0.],\n",
      "       [200.,  11.,   0.,   0.,   1.,   0.,   0.],\n",
      "       [201.,  11.,   1.,   0.,   1.,   0.,   0.]], dtype=float32)\n",
      "ele[0][0] = array([[256.,  11.,  56.,   0.,   0.,   1.,   0.],\n",
      "       [257.,  11.,  57.,   0.,   0.,   1.,   0.],\n",
      "       [258.,  11.,  58.,   0.,   0.,   1.,   0.],\n",
      "       [259.,  11.,  59.,   0.,   0.,   1.,   0.],\n",
      "       [260.,  12.,   0.,   0.,   0.,   1.,   0.],\n",
      "       [261.,  12.,   1.,   0.,   0.,   1.,   0.],\n",
      "       [262.,  12.,   2.,   0.,   0.,   1.,   0.],\n",
      "       [263.,  12.,   3.,   0.,   0.,   1.,   0.],\n",
      "       [264.,  12.,   4.,   0.,   0.,   1.,   0.],\n",
      "       [265.,  12.,   5.,   0.,   0.,   1.,   0.]], dtype=float32)\n",
      "ele[0][0] = array([[320.,  13.,   0.,   0.,   1.,   0.,   0.],\n",
      "       [321.,  13.,   1.,   0.,   1.,   0.,   0.],\n",
      "       [322.,  13.,   2.,   0.,   1.,   0.,   0.],\n",
      "       [323.,  13.,   3.,   0.,   1.,   0.,   0.],\n",
      "       [324.,  13.,   4.,   0.,   1.,   0.,   0.],\n",
      "       [325.,  13.,   5.,   0.,   1.,   0.,   0.],\n",
      "       [326.,  13.,   6.,   0.,   1.,   0.,   0.],\n",
      "       [327.,  13.,   7.,   0.,   1.,   0.,   0.],\n",
      "       [328.,  13.,   8.,   0.,   1.,   0.,   0.],\n",
      "       [329.,  13.,   9.,   0.,   1.,   0.,   0.]], dtype=float32)\n",
      "ele[0][0] = array([[384.,  14.,   4.,   0.,   0.,   0.,   1.],\n",
      "       [385.,  14.,   5.,   0.,   0.,   0.,   1.],\n",
      "       [386.,  14.,   6.,   0.,   0.,   0.,   1.],\n",
      "       [387.,  14.,   7.,   0.,   0.,   0.,   1.],\n",
      "       [388.,  14.,   8.,   0.,   0.,   0.,   1.],\n",
      "       [389.,  14.,   9.,   0.,   0.,   0.,   1.],\n",
      "       [390.,  14.,  10.,   0.,   0.,   0.,   1.],\n",
      "       [391.,  14.,  11.,   0.,   0.,   0.,   1.],\n",
      "       [392.,  14.,  12.,   0.,   0.,   0.,   1.],\n",
      "       [393.,  14.,  13.,   0.,   0.,   0.,   1.]], dtype=float32)\n",
      "ele[0][0] = array([[448.,  15.,   8.,   0.,   1.,   0.,   0.],\n",
      "       [449.,  15.,   9.,   0.,   1.,   0.,   0.],\n",
      "       [450.,  15.,  10.,   0.,   1.,   0.,   0.],\n",
      "       [451.,  15.,  11.,   0.,   1.,   0.,   0.],\n",
      "       [452.,  15.,  12.,   0.,   1.,   0.,   0.],\n",
      "       [453.,  15.,  13.,   0.,   1.,   0.,   0.],\n",
      "       [454.,  15.,  14.,   0.,   1.,   0.,   0.],\n",
      "       [455.,  15.,  15.,   0.,   1.,   0.,   0.],\n",
      "       [456.,  15.,  16.,   0.,   1.,   0.,   0.],\n",
      "       [457.,  15.,  17.,   0.,   1.,   0.,   0.]], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# Debug dump: print the first input window of every training batch (ele is (inputs, labels)).\n",
    "for ele in wg_sub.training_ds.as_numpy_iterator():\n",
    "    print(f\"{ele[0][0] = }\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_2\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " dense_4 (Dense)             (None, 10, 16)            128       \n",
      "                                                                 \n",
      " flatten_2 (Flatten)         (None, 160)               0         \n",
      "                                                                 \n",
      " dense_5 (Dense)             (None, 4)                 644       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 772\n",
      "Trainable params: 772\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "# Baseline: a Dense projection applied per time step, flattened, then one logit per sleep stage\n",
    "# (no softmax, because compile_and_fit's default loss uses from_logits=True).\n",
    "BASELINE_UNITS = 16\n",
    "baseline_model = keras.Sequential(\n",
    "    [\n",
    "        layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)),\n",
    "        layers.Dense(BASELINE_UNITS),\n",
    "        layers.Flatten(),\n",
    "        layers.Dense(SLEEP_STAGES),\n",
    "    ]\n",
    ")\n",
    "# The Input layer already builds the model; summary() prints the table itself and returns None.\n",
    "baseline_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "8955/8955 [==============================] - 60s 7ms/step - loss: 2.7762 - categorical_crossentropy: 2.7745 - categorical_accuracy: 0.7612 - categorical_hinge: 15.1798 - val_loss: 0.2711 - val_categorical_crossentropy: 0.2710 - val_categorical_accuracy: 0.9546 - val_categorical_hinge: 5.1404\n",
      "Epoch 2/20\n",
      "8955/8955 [==============================] - 60s 7ms/step - loss: 0.3110 - categorical_crossentropy: 0.3111 - categorical_accuracy: 0.9418 - categorical_hinge: 1.5468 - val_loss: 0.2106 - val_categorical_crossentropy: 0.2105 - val_categorical_accuracy: 0.9583 - val_categorical_hinge: 0.3215\n",
      "Epoch 3/20\n",
      "8955/8955 [==============================] - 60s 7ms/step - loss: 0.2253 - categorical_crossentropy: 0.2254 - categorical_accuracy: 0.9551 - categorical_hinge: 0.2726 - val_loss: 0.2068 - val_categorical_crossentropy: 0.2067 - val_categorical_accuracy: 0.9582 - val_categorical_hinge: 0.3656\n",
      "Epoch 4/20\n",
      "8955/8955 [==============================] - 63s 7ms/step - loss: 0.2059 - categorical_crossentropy: 0.2059 - categorical_accuracy: 0.9582 - categorical_hinge: 0.2735 - val_loss: 0.1957 - val_categorical_crossentropy: 0.1957 - val_categorical_accuracy: 0.9591 - val_categorical_hinge: 0.3622\n",
      "Epoch 5/20\n",
      "8955/8955 [==============================] - 82s 9ms/step - loss: 0.2016 - categorical_crossentropy: 0.2016 - categorical_accuracy: 0.9587 - categorical_hinge: 0.2638 - val_loss: 0.1970 - val_categorical_crossentropy: 0.1970 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2868\n",
      "Epoch 6/20\n",
      "8955/8955 [==============================] - 72s 8ms/step - loss: 0.1995 - categorical_crossentropy: 0.1995 - categorical_accuracy: 0.9585 - categorical_hinge: 0.2506 - val_loss: 0.1973 - val_categorical_crossentropy: 0.1973 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.3453\n",
      "Epoch 7/20\n",
      "8955/8955 [==============================] - 75s 8ms/step - loss: 0.1991 - categorical_crossentropy: 0.1992 - categorical_accuracy: 0.9591 - categorical_hinge: 0.2518 - val_loss: 0.1969 - val_categorical_crossentropy: 0.1968 - val_categorical_accuracy: 0.9591 - val_categorical_hinge: 0.2474\n",
      "Epoch 8/20\n",
      "8955/8955 [==============================] - 75s 8ms/step - loss: 0.1983 - categorical_crossentropy: 0.1983 - categorical_accuracy: 0.9587 - categorical_hinge: 0.2560 - val_loss: 0.1957 - val_categorical_crossentropy: 0.1956 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2912\n",
      "Epoch 9/20\n",
      "8955/8955 [==============================] - 76s 8ms/step - loss: 0.1965 - categorical_crossentropy: 0.1965 - categorical_accuracy: 0.9595 - categorical_hinge: 0.2328 - val_loss: 0.1957 - val_categorical_crossentropy: 0.1957 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.3268\n",
      "Epoch 10/20\n",
      "8955/8955 [==============================] - 77s 9ms/step - loss: 0.1961 - categorical_crossentropy: 0.1961 - categorical_accuracy: 0.9598 - categorical_hinge: 0.2375 - val_loss: 0.1954 - val_categorical_crossentropy: 0.1954 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2518\n",
      "Epoch 11/20\n",
      "8955/8955 [==============================] - 72s 8ms/step - loss: 0.1980 - categorical_crossentropy: 0.1981 - categorical_accuracy: 0.9584 - categorical_hinge: 0.2387 - val_loss: 0.1948 - val_categorical_crossentropy: 0.1948 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2795\n",
      "Epoch 12/20\n",
      "8955/8955 [==============================] - 69s 8ms/step - loss: 0.1959 - categorical_crossentropy: 0.1959 - categorical_accuracy: 0.9594 - categorical_hinge: 0.2465 - val_loss: 0.1950 - val_categorical_crossentropy: 0.1950 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2824\n",
      "Epoch 13/20\n",
      "8955/8955 [==============================] - 65s 7ms/step - loss: 0.1960 - categorical_crossentropy: 0.1960 - categorical_accuracy: 0.9596 - categorical_hinge: 0.2413 - val_loss: 0.1955 - val_categorical_crossentropy: 0.1954 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.3859\n",
      "Epoch 14/20\n",
      "8955/8955 [==============================] - 68s 8ms/step - loss: 0.1959 - categorical_crossentropy: 0.1959 - categorical_accuracy: 0.9591 - categorical_hinge: 0.2280 - val_loss: 0.1941 - val_categorical_crossentropy: 0.1941 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.3046\n",
      "Epoch 15/20\n",
      "8955/8955 [==============================] - 77s 9ms/step - loss: 0.2031 - categorical_crossentropy: 0.2031 - categorical_accuracy: 0.9580 - categorical_hinge: 0.2336 - val_loss: 0.1949 - val_categorical_crossentropy: 0.1948 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.1873\n",
      "Epoch 16/20\n",
      "8955/8955 [==============================] - 68s 8ms/step - loss: 0.1966 - categorical_crossentropy: 0.1966 - categorical_accuracy: 0.9589 - categorical_hinge: 0.2361 - val_loss: 0.1938 - val_categorical_crossentropy: 0.1938 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2056\n",
      "Epoch 17/20\n",
      "8955/8955 [==============================] - 67s 7ms/step - loss: 0.1924 - categorical_crossentropy: 0.1924 - categorical_accuracy: 0.9599 - categorical_hinge: 0.2510 - val_loss: 0.1934 - val_categorical_crossentropy: 0.1934 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.3111\n",
      "Epoch 18/20\n",
      "8955/8955 [==============================] - 67s 7ms/step - loss: 0.2079 - categorical_crossentropy: 0.2079 - categorical_accuracy: 0.9569 - categorical_hinge: 0.3273 - val_loss: 0.1930 - val_categorical_crossentropy: 0.1929 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.2956\n",
      "Epoch 19/20\n",
      "8955/8955 [==============================] - 67s 7ms/step - loss: 0.1923 - categorical_crossentropy: 0.1923 - categorical_accuracy: 0.9599 - categorical_hinge: 0.3248 - val_loss: 0.1924 - val_categorical_crossentropy: 0.1924 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.6417\n",
      "Epoch 20/20\n",
      "8955/8955 [==============================] - 68s 8ms/step - loss: 0.1950 - categorical_crossentropy: 0.1950 - categorical_accuracy: 0.9595 - categorical_hinge: 0.3199 - val_loss: 0.1935 - val_categorical_crossentropy: 0.1935 - val_categorical_accuracy: 0.9592 - val_categorical_hinge: 0.5098\n"
     ]
    }
   ],
   "source": [
    "# Train the baseline on the subset with compile_and_fit's defaults (early stopping on val_loss).\n",
    "baseline_history = compile_and_fit(baseline_model, wg_sub)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['loss', 'categorical_crossentropy', 'categorical_accuracy', 'categorical_hinge', 'val_loss', 'val_categorical_crossentropy', 'val_categorical_accuracy', 'val_categorical_hinge'])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Metric names recorded per epoch by fit: training metrics and their val_-prefixed validation counterparts.\n",
    "baseline_history.history.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_history(history, file_name: str = \"history.csv\"):\n",
    "    \"\"\"Write a training-history mapping to CSV, one column per key.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    history : dict\n",
    "        Mapping of metric name -> list of values, e.g. ``History.history``.\n",
    "    file_name : str\n",
    "        Destination CSV path; missing parent directories are created.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    parent_dir = os.path.dirname(file_name)\n",
    "    if parent_dir:\n",
    "        # to_csv does not create directories (e.g. .history/ on a fresh checkout).\n",
    "        os.makedirs(parent_dir, exist_ok=True)\n",
    "    pd.DataFrame.from_dict(history).to_csv(file_name, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_history(baseline_history.history, \".history/baseline.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-05-11 15:02:03.185843: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: .model/baseline/assets\n"
     ]
    }
   ],
   "source": [
    "# Persist the trained baseline under .model/baseline (written as a SavedModel directory, per the assets log).\n",
    "baseline_model.save(\".model/baseline\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Prep\n",
    "\n",
    "All inputs follow: (batch_size, timesteps, input_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-05-12 14:18:02.248350: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "WindowGenerator:\n",
       "\tTotal window size: 6\n",
       "\tInput indices: [0 1 2 3 4]\n",
       "\tLabel indices: [5]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Window generator over the full dataset with the default window width and split sizes.\n",
    "wg = WindowGenerator(sleep_data)\n",
    "wg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4682"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of training batches (each dataset element is one batch of windows).\n",
    "len(wg.training_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ele[0].numpy()[0] = array([[ 0.   ,  8.   , 18.   ,  0.68 ,  0.057,  0.2  ,  0.063],\n",
      "       [ 1.   ,  8.   , 19.   ,  0.652,  0.06 ,  0.224,  0.064],\n",
      "       [ 2.   ,  8.   , 20.   ,  0.672,  0.059,  0.209,  0.06 ],\n",
      "       [ 3.   ,  8.   , 21.   ,  0.645,  0.056,  0.235,  0.064],\n",
      "       [ 4.   ,  8.   , 22.   ,  0.644,  0.054,  0.244,  0.058]],\n",
      "      dtype=float32)\n",
      "ele[1].numpy()[0] = array([1., 0., 0., 0.], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# Show the first input window and its label from the first batch of sample_ds.\n",
    "for ele in wg.sample_ds.take(1):\n",
    "    print(f\"{ele[0].numpy()[0] = }\")\n",
    "    print(f\"{ele[1].numpy()[0] = }\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model 1: LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters\n",
    "LSTM_UNITS = 8  # LSTM output size\n",
    "LSTM_LEARNING_RATE = 0.0001  # Adam learning rate used for lstm_optm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_2\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " lstm (LSTM)                 (None, 8)                 512       \n",
      "                                                                 \n",
      " dense_2 (Dense)             (None, 4)                 36        \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 548\n",
      "Trainable params: 548\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "# Model definition: one LSTM layer summarising the input window, then one logit per sleep stage.\n",
    "lstm_model = keras.Sequential(\n",
    "    [\n",
    "        layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)),\n",
    "        layers.LSTM(LSTM_UNITS, stateful=False, return_sequences=False),\n",
    "        layers.Dense(SLEEP_STAGES),\n",
    "    ]\n",
    ")\n",
    "# The Input layer already builds the model; summary() prints the table itself and returns None.\n",
    "lstm_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model Training\n",
    "# Labels are one-hot vectors and the model emits logits, so plain `Accuracy()` (exact\n",
    "# element-wise equality of predictions and labels) is meaningless here;\n",
    "# `CategoricalAccuracy()` compares the argmax of logits and labels instead.\n",
    "# NOTE(review): these objects are not passed to compile_and_fit below - confirm they are used.\n",
    "lstm_loss = tf.losses.CategoricalCrossentropy(from_logits=True)\n",
    "lstm_optm = tf.optimizers.Adam(learning_rate=LSTM_LEARNING_RATE)\n",
    "lstm_metrics = [\n",
    "    tf.keras.metrics.CategoricalCrossentropy(from_logits=True),\n",
    "    tf.keras.metrics.BinaryCrossentropy(from_logits=True),\n",
    "    tf.keras.metrics.CategoricalAccuracy(),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9647/9647 [==============================] - 88s 9ms/step - loss: 0.6389 - categorical_crossentropy: 0.6388 - categorical_accuracy: 0.7669 - categorical_hinge: 0.5620 - val_loss: 0.2218 - val_categorical_crossentropy: 0.2217 - val_categorical_accuracy: 0.9551 - val_categorical_hinge: 0.2163\n"
     ]
    }
   ],
   "source": [
    "# Compile and train the LSTM for one epoch on the `wg` window data\n",
    "lstm_history = compile_and_fit(window=wg, model=lstm_model, epochs=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model 2: GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters for the GRU model\n",
    "GRU_UNITS = 16  # size of the GRU hidden state (output width of the GRU layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential_1\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " gru (GRU)                   (None, 16)                1200      \n",
      "                                                                 \n",
      " dense_1 (Dense)             (None, 4)                 68        \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 1,268\n",
      "Trainable params: 1,268\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "# Model Definition: one GRU layer over the input window, then a Dense layer with one\n",
    "# linear output (logit) per sleep stage.\n",
    "gru_model = keras.Sequential([\n",
    "    layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)),\n",
    "    layers.GRU(GRU_UNITS),\n",
    "    layers.Dense(SLEEP_STAGES),\n",
    "])\n",
    "# summary() prints the table itself and returns None; wrapping it in print() adds a stray \"None\"\n",
    "gru_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9647/9647 [==============================] - 97s 10ms/step - loss: 0.4700 - categorical_crossentropy: 0.4700 - categorical_accuracy: 0.8416 - categorical_hinge: 0.4278 - val_loss: 0.2088 - val_categorical_crossentropy: 0.2087 - val_categorical_accuracy: 0.9595 - val_categorical_hinge: 0.1954\n"
     ]
    }
   ],
   "source": [
    "# Compile and train the GRU for one epoch on the `wg` window data\n",
    "gru_history = compile_and_fit(window=wg, model=gru_model, epochs=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model 3: Attention Mechanism"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters for the attention model\n",
    "ATTENTION_UNITS = 32  # passed to CustomAttention; matches its (None, 32) output in the model summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " custom_attention (CustomAtt  (None, 32)               497       \n",
      " ention)                                                         \n",
      "                                                                 \n",
      " dense (Dense)               (None, 4)                 132       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 629\n",
      "Trainable params: 629\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "# Model Definition: CustomAttention (imported from the local `attention` module) over the\n",
    "# input window, then a Dense layer with one linear output (logit) per sleep stage.\n",
    "am_model = keras.Sequential([\n",
    "    layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)),\n",
    "    CustomAttention(ATTENTION_UNITS),\n",
    "    layers.Dense(SLEEP_STAGES),\n",
    "])\n",
    "# summary() prints the table itself and returns None; wrapping it in print() adds a stray \"None\"\n",
    "am_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x12aaebe50>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Restore previously saved attention-model weights.\n",
    "# NOTE(review): \"TEST\" is a hard-coded checkpoint prefix relative to the working directory;\n",
    "# consider defining it next to MODEL_DIR in the constants cell.\n",
    "am_model.load_weights(filepath=\"TEST\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): epochs=0 runs no training steps, so am_history.history is empty; presumably this\n",
    "# call only compiles am_model (keeping the weights just loaded) so it can be evaluated below.\n",
    "am_history = compile_and_fit(window=wg, model=am_model, epochs=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1442/1442 [==============================] - 14s 4ms/step - loss: 0.1957 - categorical_crossentropy: 0.1957 - categorical_accuracy: 0.9598 - categorical_hinge: 0.1911\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.19572481513023376,\n",
       " 0.19574123620986938,\n",
       " 0.9597873687744141,\n",
       " 0.1910863220691681]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the attention model on the test split provided by `wg`\n",
    "am_model.evaluate(x=wg.testing_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.59796184, -0.20803624, -0.2916372 , -0.7170148 ]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Classify one hand-made input window (one batch element). The model ends in a plain Dense\n",
    "# layer, so `predict` returns unnormalised logits; softmax converts them to per-stage\n",
    "# probabilities. Dividing by the L2 norm does NOT give probabilities: the values can be\n",
    "# negative and do not sum to 1.\n",
    "# NOTE(review): column order of the sample is assumed to match the training features - confirm.\n",
    "sample_window = np.array(\n",
    "    [[\n",
    "        [ 150.,  10., 20.,  0.5,  0.1,  0.,  0.4],\n",
    "        [ 151.,  10., 21.,  0.5,  0.2,  0.,  0.3],\n",
    "        [ 152.,  10., 22.,  0.5,  0.3,  0.,  0.2],\n",
    "        [ 153.,  10., 23.,  0.5,  0.4,  0.,  0.1],\n",
    "        [ 154.,  10., 24.,  0.2,  0.1,  0.,  0.],\n",
    "        [ 155.,  10., 25.,  0.2,  0.2,  0.,  0.],\n",
    "        [ 156.,  10., 26.,  0.2,  0.2,  0.1,  0.],\n",
    "        [ 157.,  10., 27.,  0.2,  0.25,  0.15,  0.],\n",
    "        [ 158.,  10., 28.,  0.2,  0.3,  0.2,  0.],\n",
    "        [ 159.,  10., 29.,  0.2,  0.3,  0.3,  0.]\n",
    "    ]]\n",
    ")\n",
    "answer = am_model.predict(sample_window)\n",
    "normalized_answer = tf.nn.softmax(answer, axis=-1).numpy()\n",
    "normalized_answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the attention model's Keras training history (one column per metric) as CSV.\n",
    "# NOTE(review): HISTORY_DIR is only defined in the Head-to-Head section further down, so this\n",
    "# cell raises NameError on a fresh Restart & Run All - move that constant to the CONSTANTS cell.\n",
    "# With epochs=0 in the fit cell above, this history is empty.\n",
    "history_frame = pd.DataFrame.from_dict(am_history.history)\n",
    "history_frame.to_csv(f\"{HISTORY_DIR}/am_220512.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Help on method save_weights in module keras.engine.training:\n",
      "\n",
      "save_weights(filepath, overwrite=True, save_format=None, options=None) method of keras.engine.sequential.Sequential instance\n",
      "    Saves all layer weights.\n",
      "    \n",
      "    Either saves in HDF5 or in TensorFlow format based on the `save_format`\n",
      "    argument.\n",
      "    \n",
      "    When saving in HDF5 format, the weight file has:\n",
      "      - `layer_names` (attribute), a list of strings\n",
      "          (ordered names of model layers).\n",
      "      - For every layer, a `group` named `layer.name`\n",
      "          - For every such layer group, a group attribute `weight_names`,\n",
      "              a list of strings\n",
      "              (ordered names of weights tensor of the layer).\n",
      "          - For every weight in the layer, a dataset\n",
      "              storing the weight value, named after the weight tensor.\n",
      "    \n",
      "    When saving in TensorFlow format, all objects referenced by the network are\n",
      "    saved in the same format as `tf.train.Checkpoint`, including any `Layer`\n",
      "    instances or `Optimizer` instances assigned to object attributes. For\n",
      "    networks constructed from inputs and outputs using `tf.keras.Model(inputs,\n",
      "    outputs)`, `Layer` instances used by the network are tracked/saved\n",
      "    automatically. For user-defined classes which inherit from `tf.keras.Model`,\n",
      "    `Layer` instances must be assigned to object attributes, typically in the\n",
      "    constructor. See the documentation of `tf.train.Checkpoint` and\n",
      "    `tf.keras.Model` for details.\n",
      "    \n",
      "    While the formats are the same, do not mix `save_weights` and\n",
      "    `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be\n",
      "    loaded using `Model.load_weights`. Checkpoints saved using\n",
      "    `tf.train.Checkpoint.save` should be restored using the corresponding\n",
      "    `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over\n",
      "    `save_weights` for training checkpoints.\n",
      "    \n",
      "    The TensorFlow format matches objects and variables by starting at a root\n",
      "    object, `self` for `save_weights`, and greedily matching attribute\n",
      "    names. For `Model.save` this is the `Model`, and for `Checkpoint.save` this\n",
      "    is the `Checkpoint` even if the `Checkpoint` has a model attached. This\n",
      "    means saving a `tf.keras.Model` using `save_weights` and loading into a\n",
      "    `tf.train.Checkpoint` with a `Model` attached (or vice versa) will not match\n",
      "    the `Model`'s variables. See the\n",
      "    [guide to training checkpoints](https://www.tensorflow.org/guide/checkpoint)\n",
      "    for details on the TensorFlow format.\n",
      "    \n",
      "    Args:\n",
      "        filepath: String or PathLike, path to the file to save the weights to.\n",
      "            When saving in TensorFlow format, this is the prefix used for\n",
      "            checkpoint files (multiple files are generated). Note that the '.h5'\n",
      "            suffix causes weights to be saved in HDF5 format.\n",
      "        overwrite: Whether to silently overwrite any existing file at the\n",
      "            target location, or provide the user with a manual prompt.\n",
      "        save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or\n",
      "            '.keras' will default to HDF5 if `save_format` is `None`. Otherwise\n",
      "            `None` defaults to 'tf'.\n",
      "        options: Optional `tf.train.CheckpointOptions` object that specifies\n",
      "            options for saving weights.\n",
      "    \n",
      "    Raises:\n",
      "        ImportError: If `h5py` is not available when attempting to save in HDF5\n",
      "            format.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Development aid: display the docstring of Model.save_weights.\n",
    "# NOTE(review): leftover exploration cell - consider removing it before sharing the notebook.\n",
    "help(am_model.save_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-05-12 12:12:02.267722: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.\n",
      "WARNING:absl:Found untraced functions such as attention_score_vec_layer_call_fn, attention_score_vec_layer_call_and_return_conditional_losses, last_hidden_state_layer_call_fn, last_hidden_state_layer_call_and_return_conditional_losses, attention_score_layer_call_fn while saving (showing 5 of 14). These functions will not be directly callable after loading.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: .model/attention/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: .model/attention/assets\n"
     ]
    }
   ],
   "source": [
    "# Save the full attention model in TensorFlow SavedModel format under MODEL_DIR.\n",
    "# NOTE(review): MODEL_DIR is only defined in the Head-to-Head section further down; move it to\n",
    "# the CONSTANTS cell. Reloading a model with a custom layer typically needs `custom_objects`.\n",
    "am_model.save(f\"{MODEL_DIR}/attention\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Head-to-Head testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directories: training histories (CSV) and saved models.\n",
    "# NOTE(review): earlier cells already use these names; they belong in the CONSTANTS cell at the top.\n",
    "HISTORY_DIR, MODEL_DIR = \".history\", \".model\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns compared across models (\"categorical_accuracy\" is another available option)\n",
    "history_columns = [\"categorical_crossentropy\", \"val_categorical_crossentropy\"]\n",
    "# Load the saved training histories of each model\n",
    "base_history = pd.read_csv(f\"{HISTORY_DIR}/baseline.csv\")\n",
    "lstm_history = pd.read_csv(f\"{HISTORY_DIR}/lstm.csv\")\n",
    "gru_history = pd.read_csv(f\"{HISTORY_DIR}/gru.csv\")\n",
    "attn_history = pd.read_csv(f\"{HISTORY_DIR}/attention.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:>"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAtAElEQVR4nO3deXRU9f3/8ed71uwQEvZFtOWromyC4I7gglsBbQGrLSJSj1XUalvFHy60tUerllqtR6pVEMWqoCBtaWv9CrVUoSYUkUUR+CKGNQmQkIUsM+/fH7MwhGyQZZI778c5c2bunTv3vhkmr/nM5977uaKqGGOMaf9c8S7AGGNM87BAN8YYh7BAN8YYh7BAN8YYh7BAN8YYh/DEa8PZ2dnat2/feG3eGGPapdzc3AJV7Vzbc3EL9L59+5KTkxOvzRtjTLskIl/V9Zx1uRhjjENYoBtjjENYoBtjjEPErQ/dmOZWVVVFXl4ehw8fjncpxjRZUlISvXr1wuv1Nvo1FujGMfLy8khPT6dv376ISLzLMeaEqSqFhYXk5eVx8sknN/p11uViHOPw4cNkZWVZmJt2T0TIyso67l+bFujGUSzMjVOcyGe53QX6J9v386u/fY4N+2uMMUdrd4H+6dcHeX7FVorLq+NdijHGtCntLtCz0/wAFJRWxLkSY5pmxYoVfPTRR62yrauuuoqDBw8e9+vmzZvH9OnTm7+gFrZ9+3Zef/31eJfR6tpdoGel+QAoLKmMcyXGNE1rBLqqEgwGWbZsGR07dmzRbdW3/dZWX6BXVzv31327O2wxKzXUQi8ssRa6qdvP/rSBjbuKm3Wd/Xtk8Mi3zmhwufnz5/PUU08hIgwcOJCJEyfy6KOPUllZSVZWFgsWLKC8vJw5c+bgdrt57bXXePbZZznttNO47bbb2LFjBwBPP/00559/Pvn5+dxwww3s2rWLc889l3/84x/k5uaSnZ3N7NmzefnllwGYNm0aP/rRj9i+fTtjxoxhxIgR5ObmsmzZMkaOHElOTg7Z2dnH1Pfqq6/ypz/96Zgau3bt2uC/de/evdx2221s27YNgOeff54ePXocs/3f/e53/PWvf0VEePDBB5k0aRK7d+9m0qRJFBcXU11dzfPPP895553HLbfcQk5ODiLC1KlTueeee9i6dSt33HEH+fn5pKSk8OKLL3LaaacxZcoUMjIyyMnJYc+ePTzxxBN85zvfYcaMGWzatInBgwdz0003kZmZyTvvvENJSQmBQIDFixczdepUtm3bRkpKCi+88AIDBw5k1qxZbN26lS1btlBQUMB9993HD37wAyZPnsx1113H+PHjAbjxxhuZOHEi48aNO8FPU8tod4GeHW6hF5RaC920PRs2bODRRx/lo48+Ijs7m/379yMirFq1ChHhD3/4A0888QS//vWvue2220hLS+MnP/kJADfccAP33HMPF1xwATt27GDMmDFs2rSJn/3sZ4wePZoHHniAv/3tb7z00ksA5ObmMnfuXFavXo2qMmLECEaOHElmZiZffvklr7zyCuecc06D9QFccMEFtdbYkLvuuouRI0eyePFiAoEAJSUlHDhw4Kjtv/3226xdu5ZPP/2UgoICzj77bC666CJef/11xowZw8yZMwkEApSVlbF27Vp27tzJ+vXrAaLdRLfeeitz5syhX79+rF69mttvv50PPvgAgN27d7Ny5Uo+//xzxo4dy3e+8x0ef/xxnnrqKf785z8Doa6jNWvWsG7dOjp16sSdd97JkCFDWLJkCR988AGTJ09m7dq1AKxbt45Vq1ZRWlrKkCFDuPrqq7nlllv4zW9+w/jx4ykqKuKjjz7ilVdeadqHpQW0u0DPTI10uVgL3dStMS3plvDBBx8wYcIEsrOzAejUqROfffZZtEVaWVlZ54ki77//Phs3boxOFxcXU1JSwsqVK1m8eDEAV1xxBZmZmQCsXLmSa6+9ltTUVACuu+46/vWv
fzF27FhOOumkY8K8rvogdFJWY2qsbX3z588HwO1206FDBw4cOHDU9leuXMl3v/td3G43Xbt2ZeTIkXzyySecffbZTJ06laqqKsaPH8/gwYM55ZRT2LZtG3feeSdXX301l19+OSUlJXz00UdMmDAhut2KiiN//+PHj8flctG/f3/27t1bZ62XXXZZ9N+7cuVK3n77bQBGjx5NYWEhxcWhX3Tjxo0jOTmZ5ORkRo0axX/+8x/Gjx/P7bffTn5+Pm+//Tbf/va38XjaXny2uz50r9tFxxSv9aGbduPOO+9k+vTpfPbZZ/z+97+v82SRYDDIqlWrWLt2bbSlmpaWdkLbjIR8c9fYnNu/6KKL+PDDD+nZsydTpkxh/vz5ZGZm8umnn3LxxRczZ84cpk2bRjAYpGPHjtH3Ze3atWzatCm6Hr/fH31c3+HMjX1Pah7/HZmePHkyr732GnPnzmXq1KmNWldra3eBDpCV6qPQjnIxbdDo0aNZuHAhhYWFAOzfv5+ioiJ69uwJcNTP9PT0dA4dOhSdvvzyy3n22Wej05EugPPPP5+33noLgPfee48DBw4AcOGFF7JkyRLKysooLS1l8eLFXHjhhcddH1BnjQ255JJLeP755wEIBAIUFRUds8yFF17Im2++SSAQID8/nw8//JDhw4fz1Vdf0bVrV37wgx8wbdo01qxZQ0FBAcFgkG9/+9s8+uijrFmzhoyMDE4++WQWLlwIhEL7008/rbeumu9tbTUtWLAACO2czs7OJiMjA4B3332Xw4cPU1hYyIoVKzj77LMBmDJlCk8//TQA/fv3b/R71JraZ6Cn+SmwFrppg8444wxmzpzJyJEjGTRoEPfeey+zZs1iwoQJDB06NNrVAfCtb32LxYsXM3jwYP71r3/xzDPPkJOTw8CBA+nfvz9z5swB4JFHHuG9997jzDPPZOHChXTr1o309HTOOusspkyZwvDhwxkxYgTTpk1jyJAhx10fUGeNDfntb3/L8uXLGTBgAEOHDj2qyyji2muvZeDAgQwaNIjRo0fzxBNP0K1bN1asWMGgQYMYMmQIb775JnfffTc7d+7k4osvZvDgwXzve9/jscceA2DBggW89NJLDBo0iDPOOIN333233roGDhyI2+1m0KBB/OY3vznm+VmzZpGbm8vAgQOZMWPGUV9iAwcOZNSoUZxzzjk89NBD9OjRA4CuXbty+umnc/PNNzf6/WltEq8zLocNG6YnesWi2xfksnlvCe/fO7KZqzLt2aZNmzj99NPjXUazq6iowO124/F4+Pjjj/nhD38Ybb2b5jVr1qyjdlTHKisrY8CAAaxZs4YOHTq0Sj21faZFJFdVh9W2fNvr1W+ErFQ/hSWF8S7DmFaxY8cOJk6cSDAYxOfz8eKLL8a7pITz/vvvc8stt3DPPfe0WpifiPYZ6Gk+DpRVUR0I4nG3y14jYxqtX79+/Pe//41rDb/85S+jfdgREyZMYObMmXGqqGXMmjWr1vmXXnopX31V56U824x2Guihvdr7yyrpkp4U52qMcb6ZM2c6LrydqF02b7NT7fR/Y4ypqV0GeicLdGOMOUa7DPRIl4sdi26MMUe0y0CPjudiLXRjjIlql4GekeTF4xIbz8W0eyd6an9tlixZUuuJPS3hvPPOO6HXzZo1i6eeeqqZq2l5a9euZdmyZfEuo0HtMtBdLqFTqs/60I2J0RqBHhlLvLUuzFHX9ltbfYHelsZXb5eHLUKoH9360E2d/joD9nzWvOvsNgCufLzeRWbMmEHv3r254447gFCL1OPxsHz5cg4cOEBVVRWPPvpoo8fR/tWvfsVrr72Gy+Xiyiuv5PHHH+fFF1/khRdeoLKykm9+85u8+uqrrF27lqVLl/LPf/6TRx99NDqSYG1jiG/dupUbb7yR0tJSxo0bx9NPP01JSQmqyn333XfMuOUrVqzgoYceIjMzk88//5zNmzeTlpZGSUnJcdWYkpLS4L93y5Yt3HbbbeTn5+N2
u1m4cCFff/31Udtft24dP/zhD8nJycHj8TB79mxGjRrFhg0buPnmm6msrCQYDPL222/To0cPJk6cSF5eHoFAgIceeohJkyaRm5vLvffeS0lJCdnZ2cybN4/u3btz8cUXM2LECJYvX87Bgwd56aWXGDFiBA8//DDl5eWsXLmSBx54gE2bNrF161a2bdtGnz59eOyxx5g6dSoFBQV07tyZuXPn0qdPH6ZMmUJSUhI5OTkUFxcze/ZsrrnmGi666CKeeeYZBg8eDISGL37uuecYNGhQoz4XdVLVuNyGDh2qTfG9P6zScb9b2aR1GGfZuHHjkYll96u+fFXz3pbd32ANa9as0Ysuuig6ffrpp+uOHTu0qKhIVVXz8/P1G9/4hgaDQVVVTU1NrXNdy5Yt03PPPVdLS0tVVbWwsFBVVQsKCqLLzJw5U5955hlVVb3pppt04cKF0edGjx6tmzdvVlXVVatW6ahRo1RV9eqrr9bXX39dVVWff/75aA2LFi3SSy+9VKurq3XPnj3au3dv3bVrly5fvlxTUlJ027Zt0XVHXnO8NT7yyCP65JNP1vlvHj58uL7zzjuqqlpeXq6lpaXHbP+pp57Sm2++WVVVN23apL1799by8nKdPn26vvbaa6qqWlFRoWVlZbpo0SKdNm1adP0HDx7UyspKPffcc3Xfvn2qqvrGG29E1zdy5Ei99957VVX1L3/5i15yySWqqjp37ly94447out55JFH9KyzztKysjJVVb3mmmt03rx5qqr60ksv6bhx46L/J2PGjNFAIKCbN2/Wnj17anl5uc6bN0/vvvtuVVX94osvtK48POozHQbkaB252n5b6Kk+theWxrsM01Y10JJuKUOGDGHfvn3s2rWL/Px8MjMz6datG/fccw8ffvghLpeLnTt3snfvXrp161bvut5//31uvvnmaMs2Mpb3+vXrefDBBzl48CAlJSWMGTPmmNfWN4b4xx9/zJIlS4DQRTUi45bUNW55RkYGw4cPr3WM9KbUWNOhQ4fYuXMn1157LQBJSUdOGozd/sqVK7nzzjsBOO200zjppJPYvHkz5557Lr/85S/Jy8vjuuuuo1+/fgwYMIAf//jH3H///VxzzTVceOGFrF+/nvXr13PZZZcBoVEiu3fvHt3WddddB8DQoUPZvn17nfWOHTuW5OTk6Hv6zjvvAPD973+f++67L7rcxIkTcblc9OvXj1NOOYXPP/+cCRMm8Itf/IInn3ySl19+mSlTpjT4/jRG+w30NL/1oZs2acKECSxatIg9e/YwadIkFixYQH5+Prm5uXi9Xvr27duk8canTJnCkiVLGDRoEPPmzWPFihXHLBM7hnhzON7x1RtTY3Nv/4YbbmDEiBH85S9/4aqrruL3v/89o0ePZs2aNSxbtowHH3yQSy65hGuvvZYzzjiDjz/+uNb1RMZXd7vd9faPN2V89ZSUFC677DLeffdd3nrrLXJzcxu1roY0aqeoiFwhIl+IyBYRmVHHMhNFZKOIbBCRFr/cdlaaj7LKAGWVbWeHhDEAkyZN4o033mDRokVMmDCBoqIiunTpgtfrZfny5Y0eE+Syyy5j7ty5lJWVAUfGLj906BDdu3enqqoqOqY3HD0GeH1jiEcuCwfwxhtvRF9f17jlzVljfdLT0+nVq1f010NFRUV0vbFixzLfvHkzO3bs4NRTT2Xbtm2ccsop3HXXXYwbN45169axa9cuUlJS+N73vsdPf/pT1qxZw6mnnkp+fn400KuqqtiwYUODtdU3vvp5550XfS8XLFhw1Lj0CxcuJBgMRvvcTz31VCB0Ddi77rqLs88+O3oVqqZqMNBFxA08B1wJ9Ae+KyL9ayzTD3gAOF9VzwB+1CzV1SM7erFoa6WbtuWMM87g0KFD9OzZk+7du3PjjTeSk5PDgAEDmD9/Pqeddlqj1nPFFVcwduxYhg0bxuDBg6OH+/3iF79gxIgRnH/++Uet6/rrr+fJJ59kyJAhbN26tc4x
xJ9++mlmz57NwIED2bJlS3T0wLrGLW/OGhvy6quv8swzzzBw4EDOO+889uzZc8wyt99+O8FgkAEDBjBp0iTmzZuH3+/nrbfe4swzz2Tw4MGsX7+eyZMn89lnnzF8+HAGDx7Mz372Mx588EF8Ph+LFi3i/vvvZ9CgQQwePLjBo3ZGjRrFxo0bGTx4MG+++eYxzz/77LPMnTs3etHt3/72t9Hn+vTpw/Dhw7nyyiuZM2dOtCtp6NChZGRkNO/46nV1rkduwLnA32OmHwAeqLHME8C0htalzbhT9P2Ne/Sk+/+s/91xoEnrMc5R2w4kc6zS0tLoTtk//vGPOnbs2DhX5Fw1d1TH2rlzp/br108DgUCdr2+JnaI9ga9jpvOAETWW+R8AEfk34AZmqerfaq5IRG4FboXQt1ZTRE//t5OLjDkuubm5TJ8+HVWlY8eOvPzyy/EuKeHMnz+fmTNnMnv2bFyu5jsdqLl2inqAfsDFQC/gQxEZoKoHYxdS1ReAFyB0xaKmbDDLBugyDvHZZ5/x/e9//6h5fr+f1atXt8j2LrzwwgavydnS7rjjDv79738fNe/uu+9u05d3OxHz5s2rdf7kyZOZPHlys2+vMYG+E+gdM90rPC9WHrBaVauA/xORzYQC/pNmqbIWWeHxXApLLdDNEap6zFEFbd2AAQMS7pJyzz33XLxLaPP0BC4P2pi2/idAPxE5WUR8wPXA0hrLLCHUOkdEsgl1wWw77mqOQ4rPQ4rPbV0uJiopKYnCwsIT+kMwpi1RVQoLC486Fr8xGmyhq2q1iEwH/k6of/xlVd0gIj8n1Dm/NPzc5SKyEQgAP1XVFr/oZ1aaz1roJqpXr17k5eWRn58f71KMabKkpCR69ep1XK9pVB+6qi4DltWY93DMYwXuDd9aTVaqnwJroZswr9db69mMxiSKdjnaYkR2mo24aIwxEe060LNSbcRFY4yJaNeB3incQredYMYY084DPSvVR3VQKS638VyMMaZdB3p2+GzRAut2McaY9h3o0ZOLbMeoMca080BPtfFcjDEmol0Hena4hV5gJxcZY0z7DvTM6ABd1kI3xph2Hehet4uOKV7rQzfGGNp5oEPo0EU7ucgYY5wQ6Gl+CqyFbowx7T/QQ+O5WAvdGGPafaCHxnOxFroxxrT/QE/zcbCsiqpAMN6lGGNMXDkg0EMnFx0os1a6MSaxtftAz7aLRRtjDOCAQI+00C3QjTGJzgGBHm6h27HoxpgE1+4DPTs8QJcdi26MSXTtPtAzkj14XGLHohtjEl67D3QRIcsuFm2MMe0/0MEuFm2MMeCUQE/zWR+6MSbhOSPQbcRFY4xxSKCn+a0P3RiT8BwS6D7KKgOUVVbHuxRjjIkbRwR6dqqdLWqMMY4I9CNni1qgG2MSl0MCPdJCtx2jxpjE5YxAtxEXjTHGIYEe7nIpsEMXjTEJrFGBLiJXiMgXIrJFRGbU8vwUEckXkbXh27TmL7VuKT4PKT63tdCNMQnN09ACIuIGngMuA/KAT0RkqapurLHom6o6vQVqbJQsu1i0MSbBNaaFPhzYoqrbVLUSeAMY17JlHT+7WLQxJtE1JtB7Al/HTOeF59X0bRFZJyKLRKR3bSsSkVtFJEdEcvLz80+g3Lpl24iLxpgE11w7Rf8E9FXVgcA/gFdqW0hVX1DVYao6rHPnzs206RAbcdEYk+gaE+g7gdgWd6/wvChVLVTVSJr+ARjaPOU1XmRMdFVt7U0bY0yb0JhA/wToJyIni4gPuB5YGruAiHSPmRwLbGq+EhsnK81PdVApLrfxXIwxianBo1xUtVpEpgN/B9zAy6q6QUR+DuSo6lLgLhEZC1QD+4EpLVhzrbJjjkXvkOJt7c0bY0zcNRjoAKq6DFhWY97DMY8fAB5o3tKOT1bMAF3faN7ueWOMaRcccaYoxAzQZceiG2MSlOMCvcCORTfGJCjHBHqnFGuhG2MSm2MC3eN20THFaycXGWMSlmMCHexi0caY
xOasQE/zU2AtdGNMgnJUoGfbiIvGmATmqEC3EReNMYnMWYGe5uNgWRVVgWC8SzHGmFbnsEAPnS16wFrpxpgE5KhAzw5fLNp2jBpjEpGjAj3SQrdDF40xichhgR45W9Ra6MaYxOOoQM+OjLhofejGmATkqEDPSPbgcYkdi26MSUiOCnQRiV6KzhhjEo2jAh3sYtHGmMTlvEBP89lhi8aYhOS4QM9Osxa6MSYxOS7Qs1KtD90Yk5icF+hpfsoqA5RVVse7FGOMaVUODHQ7ucgYk5gcF+jZkUC3k4uMMQnGcYGeFTlb1E4uMsYkGMcFeqdU63IxxiQmxwV6pA+9wA5dNMYkGMcFeorPQ4rPbS10Y0zCcVygA+HxXKyFboxJLM4MdLtYtDEmATky0LNtPBdjTAJyZKBnpfqty8UYk3CcGehpPvaXVqKq8S7FGGNajUMD3U91UCkut/FcjDGJo1GBLiJXiMgXIrJFRGbUs9y3RURFZFjzlXj8su1YdGNMAmow0EXEDTwHXAn0B74rIv1rWS4duBtY3dxFHq8jp//bjlFjTOJoTAt9OLBFVbepaiXwBjCuluV+AfwKONyM9Z2QIyMuWgvdGJM4GhPoPYGvY6bzwvOiROQsoLeq/qW+FYnIrSKSIyI5+fn5x11sYx05/d9a6MaYxNHknaIi4gJmAz9uaFlVfUFVh6nqsM6dOzd103XqlGItdGNM4mlMoO8EesdM9wrPi0gHzgRWiMh24BxgaTx3jHrcLjJTvNaHboxJKI0J9E+AfiJysoj4gOuBpZEnVbVIVbNVta+q9gVWAWNVNadFKm6kLLtYtDEmwTQY6KpaDUwH/g5sAt5S1Q0i8nMRGdvSBZ6orFQ7/d8Yk1g8jVlIVZcBy2rMe7iOZS9uellNl53m5/M9xfEuwxhjWo0jzxSF8BC6dpSLMSaBODfQU/0cLKuiKhCMdynGGNMqHBvoncLHoh+wVroxJkE4NtCzwxeLth2jxphE4dhAz0oLj+dihy4aYxKEgwM9craotdCNMYnBsYGeHR5xscBO/zfGJAjHBnpGsgePS+zQRWNMwnBsoItI6Fh0a6EbYxKEYwMdQsei77cWujEmQTg70NNsPBdjTOJwdKBn24iLxpgE4uhAz0r12WGLxpiE4exAT/NTVhmgrLI63qUYY0yLc3ig28lFxpjE4ehAz44Euh3pYoxJAI4O9Kzw2aJ2LLoxJhE4O9Cty8UYk0CcHeiR8Vzs0EVjTAJwdKAn+9yk+tzWQjfGJARHBzqEDl20PnRjTCJIgEC3i0UbYxKD8wM91W/juRhjEkICBLoNoWuMSQzOD/Q0H/tLKwkGNd6lGGNMi0qAQPdTHVSKD1fFuxRjjGlRjg/0yOn/1o9ujHE6xwe6nf5vjEkUzg90G6DLGJMgLNCNMcYhHB/onVIiA3RZl4sxxtkcH+get4vMFK+N52KMcTzHBzqEx3OxEReNMQ7XqEAXkStE5AsR2SIiM2p5/jYR+UxE1orIShHp3/ylnrisVJ8dtmiMcbwGA11E3MBzwJVAf+C7tQT266o6QFUHA08As5u70KbIthEXjTEJoDEt9OHAFlXdpqqVwBvAuNgFVLU4ZjIVaFPn2duIi8aYROBpxDI9ga9jpvOAETUXEpE7gHsBHzC6thWJyK3ArQB9+vQ53lpPWFaqn4NlVVQFgnjdCbHbwBiTgJot3VT1OVX9BnA/8GAdy7ygqsNUdVjnzp2ba9MNihyLfsBa6cYYB2tMoO8EesdM9wrPq8sbwPgm1NTsbDwXY0wiaEygfwL0E5GTRcQHXA8sjV1ARPrFTF4NfNl8JTZdVlp4PBc7dNEY42AN9qGrarWITAf+DriBl1V1g4j8HMhR1aXAdBG5FKgCDgA3tWTRxysrNXK2qLXQjTHO1ZidoqjqMmBZjXkPxzy+u5nralaRFnqBHbpojHGwhDjkIyPJg9ctduiiMcbREiLQRYSsVDu5yBjjbAkR6ACdUn3Wh26McbSE
CfSsNB8F1uVijHGwhAl0G8/FGON0CRPoWdblYoxxuMQJ9DQ/5VUByiqr412KMca0iAQKdDu5yBjjbAkT6Nl2sWhjjMMlTKBnpYbHc7Edo8YYh0qcQLcuF2OMwyVOoIdb6AU24qIxxqESJtCTfW5SfW5roRtjHCthAh1Chy5aH7oxxqkSLNDtYtHGGOdKrEBP9dtl6IwxjpVQgZ6d5rMuF2OMYyVUoGel+dhfWkkwqPEuxRhjml1iBXqqn+qgUny4Kt6lGGNMs0usQA+fXGT96MYYJ0qoQM9Os9P/jTHOlVCBnmUDdBljHCyhAr1TamQ8F2uhG2OcJ7ECPcX60I0xzpVQge5xu8hM8VJoA3QZYxwooQIdQuO57Lc+dGOMA7W/QC/cCv/+LVSfWChnpfqsy8UY40jtL9A3vAP/eBjmnA/bVhz3y7NtxEVjjEO1v0C/6Kdww1sQqIT54+Ctm6Aor9EvtxEXjTFO1f4CHeB/xsDtq2HUTNj8N/jd2fCvX0N1wy3vrFQ/B8uqeHftTg5YsBtjHERU4zNQ1bBhwzQnJ6fpKzrwFfz9/8Hnf4asb8KVT8A3L6lz8f/8335uey2X/aWVuASG9Mlk1KmdGXVaF/p3z0BEml6TMca0EBHJVdVhtT7X7gM94sv34a8/hf3b4PRvwZjHoGPvWhcNBJV1eQdZ/kU+K77Yx7q8IgC6pPsZdWoXRp3WmQv6dSbN72m++owxphk0OdBF5Argt4Ab+IOqPl7j+XuBaUA1kA9MVdWv6ltnswc6hLpcPnoWPnwqNH3Rj+G8u8Djr/dl+YcqWPHFPlZ8kc+Hm/M5VFGN1y2c3bdTOOC78I3OqdZ6N8bEXZMCXUTcwGbgMiAP+AT4rqpujFlmFLBaVctE5IfAxao6qb71tkigRxz8OtQNs2kpdDol1A3T77JGvbQqECT3qwMs/2IfKz7P54u9hwDo3Sk5FO6nduHcb2SR5HW3TO3GGFOPpgb6ucAsVR0Tnn4AQFUfq2P5IcDvVPX8+tbbooEeseV/4a/3Q+GXcOrVcMVjkHnSca0i70AZK8JdM//eUkh5VQCAJK+LVJ+HFL+bVJ+HVL+HFJ87Oi/N7yHF5yHV5ybFf+Q+ze8m2eshyevC73GH7r1ukjwukrxu/B4XHnf73FdtjGl5TQ307wBXqOq08PT3gRGqOr2O5X8H7FHVR2t57lbgVoA+ffoM/eqrentlmkd1Jax6Dv75JGgALrgXzr8bvEnHvarDVQFW/99+1u44SGllNaUV4VtlgLLKakorjtyXVlZTVhGgMhA87u14XII/JuCTvG7S3FX0cBfRTfbTQcop9XemNKk7Vf5M/F43fo8bv9eF3xP6ovB7XOHp8GOPK7xcaH0ZSR4ykr2kJ3nwexLk10ZVeegQ15J9kNYFMnqCLyXeVRlzXFot0EXke8B0YKSq1nsMYau00GMV7YT3ZsKGxZCSBR16QXImJHcK3aeE72ubl9QR3Ce2g7SyOkh5ZTjgK6spqQhQVlFNRXWQiuoAhyuroGQfntK9eEr34C/fi//wPlIO7yO1Ip/0qnzSqwpIDR6qdf0V+NhNNrs1izzNIi+QxS6y2KnZ7NIsdmsWFfjqrdHvcZGR7I2GfEaSt5ZpT3R+mv/IF4AqaOzj8OdJw9Ohx6GFYj9pbpdEb57oveuo6WOec4emIfTlWlYZuVVTVhmgouwQenAHruI8fIfy8JXkkVK2i7TDu+hQsZv06v3H/NtL3RkUebtQ4u9KSVI3ypO6cTilG5Up3alO70kgvRt+XzJ+r4skjzv6pZjic5Pi85Dsc5Pic+NthV9V1YEgh6uDVAeCpCd5cbtsn04iqi/QG5NSO4HYw0V6hefV3MilwEwaEeZx0aEnTJgHQ6fAp29AWSGUHwi12Mr2w+GDoPW0pv0dILljKOT96SAuQEAkfO+KeXzk3ofgE6EDHJkfrIZDe+DQbijZe+x2xQ1pXaFDd0g/A9K7Q3o3yOgRuvd3CL22KA9/
0df0Lcqjb1EeFH0OJXuOKb06OZvK1B5UpPagPLk7pUldKFM/JQEPJQEvxdUeiqo9HKzycLDKzf5iN/sK3Gw87CK/QjgU8KKtfMqCECSJSpKpJEUqSKKCFCpIppJ0KaOHFNBLCugp+fSSAvpJPlly9JdehXrYpVnskC7scw2hwNuV/d5ulHo7kVZ1gMzqfXQK5JNdkU+X8jxO5lMypeSodQRVKKADu7QTu8NfkJEvyt3aiZ2aTT4dcbvdJHvd4YD3kOwNBX0k8KPh73WT5HVTGQhyuCpAeWWAw9Whx0du4enqAOWVQSrCj6sCR74SXRIal6hLup/O6aH7LulJRx5n+OmclkSXDL/t70kgjWmhewjtFL2EUJB/AtygqhtilhkCLCLUkv+yMRtu9RZ6Q4JBqCiG8v2hoC87ELovPxAzL3xfURxufurR9xqsMY9jl0HB5QkFdnp3yAiHdXqPI6Gd2hlcJ/hHWF0BxbtCX1TR29dHT1eVHvdq1e0n6Eki4E4i4PKh4kFd7ug94kbFjbo8MY9Dz+PyoOIK3bvcgCDVh5GqMlzV5dGbO3Dk3hM43PA/1eWnPKUHFWm9qE7rSbBDbyTzJFyZffBm9SWlUw/8Xs9xHZ0UOFxC5YGvqdr/NcGDX6NFO5HinbgO7cJTsgtf6W481Ue/fwFxU+LtTJG3C/s9ncl3ZbNPstkd/pWUF+jE7uoUyquClFWGQtvncZEcDvdI11qqFzq5D5PpKiPTVUZHVxkZlJJBKWmUkqalpAYP4QuUU0wq+4Id2F2dxtdVaWwrS2VLWTL52uGYX2TpSZ5o0HdOT6Jzmp9knwufO9RN53O7Yu7d0Wm/24Uv3IXnC3fb+WJv7tDNZb8UWlVzHLZ4FfA0ocMWX1bVX4rIz4EcVV0qIu8DA4Dd4ZfsUNWx9a2zzQV6olCFikOh/uTq8tB9VTlUH4aqMqg6HPNceF7sc5H5werQTYNHHgcDR99roI75QfAmgzcldPOlhKdTj37sTQZf6rHL+lKhQ+/QF19rH0qqCoeLoHhnqBuvOPJFuTM8Ly/0hRqo8SPVkxT6ss7oGfoyryoPrefwwdB9+UGorL1bLcrlCXX/+VJCy1cU17pYwJdBhT+LUm8nilwd2S8d2RvIYGdVGjsq09henkJJtRBQFwFcBJHwfcxjrX2+EnoNgIQ70bwu8LpdeN0SDnnB5xG84ceR+9Dj0LS4XKEv/nBjABHEFWoIuFwuXAKC4HKBiOAScIkgEP2CDqoSVEUVguHuvtBjDU1zZDr2XsN1S3h9LhFEiD4msu3wj+1IHcTMi+0e9LgFt8sV7R70NjDtcQln9uxA704ntv8mMU4sMqatUIXSghphHxP6JXvBlwZJHUIBnRTuzmto2pty9BdYVTmU5kNJPpTuC+3sLd139HRk3uGieLwTJySIEMQV/TKp7f6I0PuhR02BEgnmmOmYZRSJztPo9JFlo9N65Lkjy4fWoBqzBo1dI9E1hL6AjmRsZP72QT/mvOtqPa6kQU3tQzfGHA8RSOscuvUY0nLb8SZDxz6hW0OqK8Lhvy+0/yjyi0kD4V9ZgXC3YSD8OFjjcTBmuQDRqJQjEXbU45rPRadDYXjUuqKPQ/cuDeAKBvDEbj+ybHT7MY5plGq9k0e6P2MfR/bga+Oer/HvVQlHdbinNRj+IgjqkS+MYHhVQeDM0/sf+3/UDCzQjUkEHn/oyK4OveJdiSNFvqrizc5gMcYYh7BAN8YYh7BAN8YYh7BAN8YYh7BAN8YYh7BAN8YYh7BAN8YYh7BAN8YYh4jbqf8ikg+c6IDo2UBBM5bT3Ky+prH6mq6t12j1nbiTVLVzbU/ELdCbQkRy6hrLoC2w+prG6mu6tl6j1dcyrMvFGGMcwgLdGGMcor0G+gvxLqABVl/TWH1N19ZrtPpaQLvsQzfGGHOs9tpCN8YYU4MFujHGOESbDnQRuUJEvhCRLSIyo5bn/SLyZvj51SLStxVr6y0iy0Vko4hsEJG7
a1nmYhEpEpG14dvDrVVfePvbReSz8LaPud6fhDwTfv/WichZrVjbqTHvy1oRKRaRH9VYptXfPxF5WUT2icj6mHmdROQfIvJl+D6zjtfeFF7mSxG5qZVqe1JEPg///y0WkY51vLbez0IL1zhLRHbG/D9eVcdr6/17b8H63oypbbuIrK3jta3yHjZJ6MKqbe9G6ILUW4FTAB/wKdC/xjK3A3PCj68H3mzF+roDZ4UfpwOba6nvYuDPcXwPtwPZ9Tx/FfBXQhdbOQdYHcf/6z2ETpiI6/sHXAScBayPmfcEMCP8eAbwq1pe1wnYFr7PDD/ObIXaLgc84ce/qq22xnwWWrjGWcBPGvEZqPfvvaXqq/H8r4GH4/keNuXWllvow4EtqrpNVSuBN4BxNZYZB7wSfrwIuESkdS4Dr6q7VXVN+PEhYBPQszW23YzGAfM1ZBXQUUS6x6GOS4CtqnqiZw43G1X9ENhfY3bs5+wVYHwtLx0D/ENV96vqAeAfwBUtXZuqvqeq1eHJVUBcrzFXx/vXGI35e2+y+uoLZ8dE4I/Nvd3W0pYDvSfwdcx0HscGZnSZ8Ie6CMhqlepihLt6hgCra3n6XBH5VET+KiJntG5lKPCeiOSKyK21PN+Y97g1XE/df0TxfP8iuqrq7vDjPUDXWpZpC+/lVEK/uGrT0GehpU0Pdwu9XEeXVVt4/y4E9qrql3U8H+/3sEFtOdDbBRFJA94GfqSqxTWeXkOoG2EQ8CywpJXLu0BVzwKuBO4QkYtaefsNEhEfMBZYWMvT8X7/jqGh395t7lhfEZkJVAML6lgknp+F54FvAIOB3YS6Ndqi71J/67zN/z215UDfCfSOme4VnlfrMiLiAToAha1SXWibXkJhvkBV36n5vKoWq2pJ+PEywCsi2a1Vn6ruDN/vAxYT+lkbqzHvcUu7ElijqntrPhHv9y/G3khXVPh+Xy3LxO29FJEpwDXAjeEvnGM04rPQYlR1r6oGVDUIvFjHtuP6WQznx3XAm3UtE8/3sLHacqB/AvQTkZPDrbjrgaU1llkKRI4m+A7wQV0f6OYW7m97CdikqrPrWKZbpE9fRIYTer9b5QtHRFJFJD3ymNDOs/U1FlsKTA4f7XIOUBTTtdBa6mwVxfP9qyH2c3YT8G4ty/wduFxEMsNdCpeH57UoEbkCuA8Yq6pldSzTmM9CS9YYu1/m2jq23Zi/95Z0KfC5qubV9mS838NGi/de2fpuhI7C2Exo7/fM8LyfE/rwAiQR+qm+BfgPcEor1nYBoZ/e64C14dtVwG3AbeFlpgMbCO2xXwWc14r1nRLe7qfhGiLvX2x9AjwXfn8/A4a18v9vKqGA7hAzL67vH6Evl91AFaF+3FsI7Zf5X+BL4H2gU3jZYcAfYl47NfxZ3ALc3Eq1bSHU9xz5DEaO+uoBLKvvs9CK79+r4c/XOkIh3b1mjeHpY/7eW6O+8Px5kc9dzLJxeQ+bcrNT/40xxiHacpeLMcaY42CBbowxDmGBbowxDmGBbowxDmGBbowxDmGBbowxDmGBbowxDvH/AdtTX6a9y/n+AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# One labelled figure per model (rows of each history frame are epochs). The original four\n",
    "# untitled plots were indistinguishable; the loop also avoids a stray <AxesSubplot> repr.\n",
    "for model_name, model_history in [\n",
    "    (\"Baseline\", base_history),\n",
    "    (\"LSTM\", lstm_history),\n",
    "    (\"GRU\", gru_history),\n",
    "    (\"Attention\", attn_history),\n",
    "]:\n",
    "    model_history[history_columns].plot(title=model_name, xlabel=\"epoch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the saved models for the head-to-head comparison.\n",
    "# Note: this replaces any in-memory lstm_model / gru_model trained earlier in the notebook.\n",
    "base_model = tf.keras.models.load_model(f\"{MODEL_DIR}/baseline\")\n",
    "lstm_model = tf.keras.models.load_model(f\"{MODEL_DIR}/lstm\")\n",
    "gru_model = tf.keras.models.load_model(f\"{MODEL_DIR}/gru\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The attention model was saved with a layer class named 'Attention', which collides with\n",
    "# the built-in keras.layers.Attention on load. Map that name to CustomAttention explicitly.\n",
    "# NOTE(review): assumes CustomAttention is the class that was saved as 'Attention' -- confirm.\n",
    "attn_model = tf.keras.models.load_model(\n",
    "    f\"{MODEL_DIR}/attention\",\n",
    "    custom_objects={\"Attention\": CustomAttention},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_time():\n",
    "    \"\"\"Placeholder stub: not implemented yet; takes no arguments and returns None.\n",
    "\n",
    "    TODO: implement (the name suggests timing the models -- confirm intent).\n",
    "    \"\"\"\n",
    "    return None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scratch"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "final_lab",
   "language": "python",
   "name": "final_lab"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}