{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook for training and testing AI models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import pandas as pd\n",
    "from attention import Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.8.0\n"
     ]
    }
   ],
   "source": [
    "# Record the TensorFlow version this notebook was executed with.\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONSTANTS\n",
    "# Raw sleep-stage export; read by the cleaning step.\n",
    "# NOTE(review): '.data/' is a hidden directory relative to the working directory - confirm './data/' was not intended.\n",
    "RAW_SLEEP_DATA_PATH = \".data/raw_bed_sleep-state.csv\"\n",
    "# Flattened output written by the cleaning step and re-loaded with pandas afterwards.\n",
    "CLEANED_SLEEP_DATA_PATH = \".data/clean_bed_sleep-state.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Parameters and Hyper-parameters\n",
    "# Number of distinct sleep-stage classes; the cleaned data carries one indicator column\n",
    "# per class (awake, light, deep, rem).\n",
    "SLEEP_STAGES = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cleaning Raw Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import datetime\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "datetime.datetime(2022, 4, 21, 10, 19, tzinfo=datetime.timezone(datetime.timedelta(seconds=7200)))"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: the ISO-8601 format string parses a timezone-aware timestamp and\n",
    "# timedelta arithmetic keeps the UTC offset.\n",
    "datetime.datetime.strptime(\"2022-04-21T10:18:00+02:00\",\"%Y-%m-%dT%H:%M:%S%z\") + datetime.timedelta(minutes=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stage_probability(stage, to_test):\n",
    "    \"\"\"One-hot indicator: 1.0 when ``stage`` equals ``to_test``, otherwise 0.0.\"\"\"\n",
    "    if stage == to_test:\n",
    "        return 1.0\n",
    "    return 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6.833333333333333"
      ]
     },
     "execution_count": 150,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hours elapsed between the session start and the most recently parsed raw row.\n",
    "# NOTE(review): start_time and in_bed_time are only assigned by the cleaning loop in a\n",
    "# later cell, so this scratch check raises NameError on Restart & Run All.\n",
    "# total_seconds() is used because timedelta.seconds silently drops whole days.\n",
    "(start_time - in_bed_time).total_seconds() / 3600"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten the raw export into one cleaned row per stage sample.\n",
    "# Raw columns as used below: [0] ISO-8601 start timestamp with UTC offset,\n",
    "# [1] bracketed list of stage durations in seconds, [2] bracketed list of stage values.\n",
    "cleaned_info = [[\n",
    "    \"sleep_id\",\n",
    "    \"sleep_begin\",\n",
    "    \"stage_start\",\n",
    "    \"time_since_begin_sec\",\n",
    "    \"stage_duration_sec\",\n",
    "    \"stage_end\",\n",
    "    \"stage_value\",\n",
    "    \"awake_probability\",\n",
    "    \"light_probability\",\n",
    "    \"deep_probability\",\n",
    "    \"rem_probability\",\n",
    "]]\n",
    "date_seen = set()\n",
    "# newline='' is what the csv module expects so that it can handle line endings itself.\n",
    "with open(RAW_SLEEP_DATA_PATH, mode='r', newline='') as raw_file:\n",
    "    csvFile = csv.reader(raw_file)\n",
    "    next(csvFile, None)  # skip the raw header row\n",
    "    in_bed_time = None\n",
    "    current_sleep_id = -1\n",
    "    for lines in csvFile:\n",
    "        if not lines:\n",
    "            continue  # tolerate blank lines in the export\n",
    "        start_time = datetime.datetime.strptime(lines[0], \"%Y-%m-%dT%H:%M:%S%z\")\n",
    "        # Rows with an already-seen start timestamp are duplicates; keep the first one only.\n",
    "        if start_time in date_seen:\n",
    "            continue\n",
    "        date_seen.add(start_time)\n",
    "        # A row that starts before the current session start opens a new sleep session.\n",
    "        if in_bed_time is None or in_bed_time > start_time:\n",
    "            current_sleep_id += 1\n",
    "            in_bed_time = start_time\n",
    "        durations = [int(value) for value in lines[1].strip(\"[]\").split(\",\") if value.strip()]\n",
    "        stages = [int(value) for value in lines[2].strip(\"[]\").split(\",\") if value.strip()]\n",
    "        # Each stage starts where the previous one ended: advance by the actual duration\n",
    "        # rather than assuming every stage lasts as long as the previous one.\n",
    "        current_time = start_time\n",
    "        for duration, stage in zip(durations, stages):\n",
    "            stage_end = current_time + datetime.timedelta(seconds=duration)\n",
    "            cleaned_info.append([\n",
    "                current_sleep_id,\n",
    "                in_bed_time,\n",
    "                current_time,\n",
    "                # total_seconds() keeps whole days; timedelta.seconds would drop them.\n",
    "                int((current_time - in_bed_time).total_seconds()),\n",
    "                duration,\n",
    "                stage_end,\n",
    "                stage,\n",
    "                stage_probability(0, stage),\n",
    "                stage_probability(1, stage),\n",
    "                stage_probability(2, stage),\n",
    "                stage_probability(3, stage),\n",
    "            ])\n",
    "            current_time = stage_end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished Writing Cleaned Data\n"
     ]
    }
   ],
   "source": [
    "# Persist the cleaned rows. newline='' lets the csv writer control line endings itself;\n",
    "# without it, platforms that translate line endings produce a blank row after every record.\n",
    "with open(CLEANED_SLEEP_DATA_PATH, 'w', newline='') as clean_file:\n",
    "    write = csv.writer(clean_file)\n",
    "    write.writerows(cleaned_info)\n",
    "print(\"Finished Writing Cleaned Data\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating DataFrame from clean raw data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the cleaned data\n",
    "# Timestamp columns are not parsed as dates here; they are converted in the next step.\n",
    "sleep_df_raw = pd.read_csv(CLEANED_SLEEP_DATA_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess data:\n",
    "#   1. convert the three timestamp columns to timezone-aware (UTC) datetimes\n",
    "for datetime_column in (\"sleep_begin\", \"stage_start\", \"stage_end\"):\n",
    "    sleep_df_raw[datetime_column] = pd.to_datetime(sleep_df_raw[datetime_column], utc=True)\n",
    "#   2. Separate time, hour and minute\n",
    "#   MAYBE 3. smaller units: int16 or int8 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_minute(row, index):\n",
    "    \"\"\"Return the minute-of-hour of the timestamp stored at ``row[index]``.\"\"\"\n",
    "    timestamp = row[index]\n",
    "    return timestamp.time().minute\n",
    "\n",
    "def get_hour(row, index):\n",
    "    \"\"\"Return the hour-of-day of the timestamp stored at ``row[index]``.\"\"\"\n",
    "    timestamp = row[index]\n",
    "    return timestamp.time().hour"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vectorised extraction through the .dt accessor instead of a Python-level row-wise apply,\n",
    "# which is very slow on a frame of this size. The cast keeps the int64 dtype that the\n",
    "# row-wise version produced regardless of the pandas version.\n",
    "sleep_df_raw[\"stage_start_hour\"] = sleep_df_raw[\"stage_start\"].dt.hour.astype(\"int64\")\n",
    "sleep_df_raw[\"stage_start_minute\"] = sleep_df_raw[\"stage_start\"].dt.minute.astype(\"int64\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 551042 entries, 0 to 551041\n",
      "Data columns (total 13 columns):\n",
      " #   Column                Non-Null Count   Dtype              \n",
      "---  ------                --------------   -----              \n",
      " 0   sleep_id              551042 non-null  int64              \n",
      " 1   sleep_begin           551042 non-null  datetime64[ns, UTC]\n",
      " 2   stage_start           551042 non-null  datetime64[ns, UTC]\n",
      " 3   time_since_begin_sec  551042 non-null  int64              \n",
      " 4   stage_duration_sec    551042 non-null  int64              \n",
      " 5   stage_end             551042 non-null  datetime64[ns, UTC]\n",
      " 6   stage_value           551042 non-null  int64              \n",
      " 7   awake_probability     551042 non-null  float64            \n",
      " 8   light_probability     551042 non-null  float64            \n",
      " 9   deep_probability      551042 non-null  float64            \n",
      " 10  rem_probability       551042 non-null  float64            \n",
      " 11  stage_start_hour      551042 non-null  int64              \n",
      " 12  stage_start_minute    551042 non-null  int64              \n",
      "dtypes: datetime64[ns, UTC](3), float64(4), int64(6)\n",
      "memory usage: 54.7 MB\n"
     ]
    }
   ],
   "source": [
    "# Check dtypes and non-null counts after the datetime conversion and hour/minute extraction.\n",
    "sleep_df_raw.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sleep_id</th>\n",
       "      <th>sleep_begin</th>\n",
       "      <th>stage_start</th>\n",
       "      <th>time_since_begin_sec</th>\n",
       "      <th>stage_duration_sec</th>\n",
       "      <th>stage_end</th>\n",
       "      <th>stage_value</th>\n",
       "      <th>awake_probability</th>\n",
       "      <th>light_probability</th>\n",
       "      <th>deep_probability</th>\n",
       "      <th>rem_probability</th>\n",
       "      <th>stage_start_hour</th>\n",
       "      <th>stage_start_minute</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:19:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:19:00+00:00</td>\n",
       "      <td>60</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:20:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:20:00+00:00</td>\n",
       "      <td>120</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:21:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:21:00+00:00</td>\n",
       "      <td>180</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:22:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:22:00+00:00</td>\n",
       "      <td>240</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:23:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551037</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:17:00+00:00</td>\n",
       "      <td>25560</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:18:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551038</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:18:00+00:00</td>\n",
       "      <td>25620</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:19:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551039</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:19:00+00:00</td>\n",
       "      <td>25680</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:20:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551040</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:20:00+00:00</td>\n",
       "      <td>25740</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:21:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551041</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:21:00+00:00</td>\n",
       "      <td>25800</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:22:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>551042 rows × 13 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        sleep_id               sleep_begin               stage_start  \\\n",
       "0              0 2022-04-21 08:18:00+00:00 2022-04-21 08:18:00+00:00   \n",
       "1              0 2022-04-21 08:18:00+00:00 2022-04-21 08:19:00+00:00   \n",
       "2              0 2022-04-21 08:18:00+00:00 2022-04-21 08:20:00+00:00   \n",
       "3              0 2022-04-21 08:18:00+00:00 2022-04-21 08:21:00+00:00   \n",
       "4              0 2022-04-21 08:18:00+00:00 2022-04-21 08:22:00+00:00   \n",
       "...          ...                       ...                       ...   \n",
       "551037      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:17:00+00:00   \n",
       "551038      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:18:00+00:00   \n",
       "551039      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:19:00+00:00   \n",
       "551040      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:20:00+00:00   \n",
       "551041      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:21:00+00:00   \n",
       "\n",
       "        time_since_begin_sec  stage_duration_sec                 stage_end  \\\n",
       "0                          0                  60 2022-04-21 08:19:00+00:00   \n",
       "1                         60                  60 2022-04-21 08:20:00+00:00   \n",
       "2                        120                  60 2022-04-21 08:21:00+00:00   \n",
       "3                        180                  60 2022-04-21 08:22:00+00:00   \n",
       "4                        240                  60 2022-04-21 08:23:00+00:00   \n",
       "...                      ...                 ...                       ...   \n",
       "551037                 25560                  60 2019-02-11 13:18:00+00:00   \n",
       "551038                 25620                  60 2019-02-11 13:19:00+00:00   \n",
       "551039                 25680                  60 2019-02-11 13:20:00+00:00   \n",
       "551040                 25740                  60 2019-02-11 13:21:00+00:00   \n",
       "551041                 25800                  60 2019-02-11 13:22:00+00:00   \n",
       "\n",
       "        stage_value  awake_probability  light_probability  deep_probability  \\\n",
       "0                 0                1.0                0.0               0.0   \n",
       "1                 0                1.0                0.0               0.0   \n",
       "2                 0                1.0                0.0               0.0   \n",
       "3                 0                1.0                0.0               0.0   \n",
       "4                 0                1.0                0.0               0.0   \n",
       "...             ...                ...                ...               ...   \n",
       "551037            1                0.0                1.0               0.0   \n",
       "551038            1                0.0                1.0               0.0   \n",
       "551039            1                0.0                1.0               0.0   \n",
       "551040            1                0.0                1.0               0.0   \n",
       "551041            1                0.0                1.0               0.0   \n",
       "\n",
       "        rem_probability  stage_start_hour  stage_start_minute  \n",
       "0                   0.0                 8                  18  \n",
       "1                   0.0                 8                  19  \n",
       "2                   0.0                 8                  20  \n",
       "3                   0.0                 8                  21  \n",
       "4                   0.0                 8                  22  \n",
       "...                 ...               ...                 ...  \n",
       "551037              0.0                13                  17  \n",
       "551038              0.0                13                  18  \n",
       "551039              0.0                13                  19  \n",
       "551040              0.0                13                  20  \n",
       "551041              0.0                13                  21  \n",
       "\n",
       "[551042 rows x 13 columns]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the frame (pandas truncates the output to the first and last rows).\n",
    "sleep_df_raw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model table: session id, clock features, then the four stage indicator columns.\n",
    "# NOTE: the indicators are ordered awake, rem, light, deep here, which differs from the\n",
    "# stage_value numbering used during cleaning (awake, light, deep, rem).\n",
    "sleep_data = sleep_df_raw[\n",
    "    [\n",
    "        \"sleep_id\",\n",
    "        \"stage_start_hour\",\n",
    "        \"stage_start_minute\",\n",
    "        \"awake_probability\",\n",
    "        \"rem_probability\",\n",
    "        \"light_probability\",\n",
    "        \"deep_probability\",\n",
    "    ]\n",
    "]\n",
    "# Elapsed time in whole minutes is placed right after sleep_id (column position 1).\n",
    "sleep_data.insert(\n",
    "    loc=1,\n",
    "    column=\"minutes_since_begin\",\n",
    "    value=sleep_df_raw[\"time_since_begin_sec\"] // 60,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   sleep_id  minutes_since_begin  stage_start_hour  stage_start_minute  \\\n",
      "0         0                    0                 8                  18   \n",
      "1         0                    1                 8                  19   \n",
      "2         0                    2                 8                  20   \n",
      "3         0                    3                 8                  21   \n",
      "4         0                    4                 8                  22   \n",
      "\n",
      "   awake_probability  rem_probability  light_probability  deep_probability  \n",
      "0                1.0              0.0                0.0               0.0  \n",
      "1                1.0              0.0                0.0               0.0  \n",
      "2                1.0              0.0                0.0               0.0  \n",
      "3                1.0              0.0                0.0               0.0  \n",
      "4                1.0              0.0                0.0               0.0  \n",
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 551042 entries, 0 to 551041\n",
      "Data columns (total 8 columns):\n",
      " #   Column               Non-Null Count   Dtype  \n",
      "---  ------               --------------   -----  \n",
      " 0   sleep_id             551042 non-null  int64  \n",
      " 1   minutes_since_begin  551042 non-null  int64  \n",
      " 2   stage_start_hour     551042 non-null  int64  \n",
      " 3   stage_start_minute   551042 non-null  int64  \n",
      " 4   awake_probability    551042 non-null  float64\n",
      " 5   rem_probability      551042 non-null  float64\n",
      " 6   light_probability    551042 non-null  float64\n",
      " 7   deep_probability     551042 non-null  float64\n",
      "dtypes: float64(4), int64(4)\n",
      "memory usage: 33.6 MB\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "print(sleep_data.head())\n",
    "# DataFrame.info() prints its report itself and returns None, so wrapping it in print()\n",
    "# only appends a stray 'None' line to the output.\n",
    "sleep_data.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sleep_id</th>\n",
       "      <th>minutes_since_begin</th>\n",
       "      <th>stage_start_hour</th>\n",
       "      <th>stage_start_minute</th>\n",
       "      <th>awake_probability</th>\n",
       "      <th>rem_probability</th>\n",
       "      <th>light_probability</th>\n",
       "      <th>deep_probability</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>8</td>\n",
       "      <td>18</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>8</td>\n",
       "      <td>19</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>8</td>\n",
       "      <td>21</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "      <td>8</td>\n",
       "      <td>22</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551037</th>\n",
       "      <td>1132</td>\n",
       "      <td>426</td>\n",
       "      <td>13</td>\n",
       "      <td>17</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551038</th>\n",
       "      <td>1132</td>\n",
       "      <td>427</td>\n",
       "      <td>13</td>\n",
       "      <td>18</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551039</th>\n",
       "      <td>1132</td>\n",
       "      <td>428</td>\n",
       "      <td>13</td>\n",
       "      <td>19</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551040</th>\n",
       "      <td>1132</td>\n",
       "      <td>429</td>\n",
       "      <td>13</td>\n",
       "      <td>20</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551041</th>\n",
       "      <td>1132</td>\n",
       "      <td>430</td>\n",
       "      <td>13</td>\n",
       "      <td>21</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>551042 rows × 8 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        sleep_id  minutes_since_begin  stage_start_hour  stage_start_minute  \\\n",
       "0              0                    0                 8                  18   \n",
       "1              0                    1                 8                  19   \n",
       "2              0                    2                 8                  20   \n",
       "3              0                    3                 8                  21   \n",
       "4              0                    4                 8                  22   \n",
       "...          ...                  ...               ...                 ...   \n",
       "551037      1132                  426                13                  17   \n",
       "551038      1132                  427                13                  18   \n",
       "551039      1132                  428                13                  19   \n",
       "551040      1132                  429                13                  20   \n",
       "551041      1132                  430                13                  21   \n",
       "\n",
       "        awake_probability  rem_probability  light_probability  \\\n",
       "0                     1.0              0.0                0.0   \n",
       "1                     1.0              0.0                0.0   \n",
       "2                     1.0              0.0                0.0   \n",
       "3                     1.0              0.0                0.0   \n",
       "4                     1.0              0.0                0.0   \n",
       "...                   ...              ...                ...   \n",
       "551037                0.0              0.0                1.0   \n",
       "551038                0.0              0.0                1.0   \n",
       "551039                0.0              0.0                1.0   \n",
       "551040                0.0              0.0                1.0   \n",
       "551041                0.0              0.0                1.0   \n",
       "\n",
       "        deep_probability  \n",
       "0                    0.0  \n",
       "1                    0.0  \n",
       "2                    0.0  \n",
       "3                    0.0  \n",
       "4                    0.0  \n",
       "...                  ...  \n",
       "551037               0.0  \n",
       "551038               0.0  \n",
       "551039               0.0  \n",
       "551040               0.0  \n",
       "551041               0.0  \n",
       "\n",
       "[551042 rows x 8 columns]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the model table (pandas truncates the output to the first and last rows).\n",
    "sleep_data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Development"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper functions and class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defaults for WindowGenerator's test_size / validation_size / input_width arguments.\n",
    "# The two sizes count distinct sleep_id values held out, not rows.\n",
    "TEST_SIZE = 122\n",
    "VALIDATION_SIZE = 183\n",
    "TIME_STEP_INPUT = 10 # in minutes\n",
    "# NOTE(review): presumably the dataset batch size; its use is not visible in this part of the notebook.\n",
    "BATCH_SIZE = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training_test_split_by_unique_index(data, index: str, test_size: int = 10, rng=None):\n",
    "    \"\"\"Split ``data`` into two frames by holding out whole groups.\n",
    "\n",
    "    ``test_size`` distinct values of ``data[index]`` are drawn without replacement.\n",
    "    Every row carrying one of them goes to the held-out frame and all remaining\n",
    "    rows go to the training frame, so no group is ever split across the two.\n",
    "\n",
    "    Args:\n",
    "        data: DataFrame containing the column ``index``.\n",
    "        index: Name of the grouping column.\n",
    "        test_size: Number of distinct group ids to hold out.\n",
    "        rng: Optional ``numpy.random.Generator`` for a reproducible draw. When\n",
    "            omitted, NumPy's global random state is used, as before.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(training_rows, held_out_rows)``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``test_size`` exceeds the number of distinct ids.\n",
    "    \"\"\"\n",
    "    unique_ids = data[index].unique()\n",
    "    choose = np.random.choice if rng is None else rng.choice\n",
    "    test_ids = choose(unique_ids, size=test_size, replace=False)\n",
    "    # Build the membership mask once and reuse it for both halves.\n",
    "    held_out_mask = data[index].isin(test_ids)\n",
    "    return data[~held_out_mask], data[held_out_mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adapted from https://www.tensorflow.org/tutorials/structured_data/time_series\n",
    "class WindowGenerator():\n",
    "    \"\"\"Builds sliding-window `tf.data` datasets from grouped time-series rows.\n",
    "\n",
    "    The frame is partitioned into training / validation / testing subsets by\n",
    "    whole groups (unique values of `index`). Each subset can be turned into a\n",
    "    dataset of `(inputs, labels)` batches, where `inputs` covers `input_width`\n",
    "    consecutive rows and `labels` is the single row that follows them.\n",
    "\n",
    "    Attributes:\n",
    "        index: Name of the grouping column.\n",
    "        training, validation, testing: Row subsets of the input frame.\n",
    "        input_width, label_width, shift, total_window_size: Window geometry.\n",
    "        input_feature_slice, label_feature_slice: Slices applied to the\n",
    "            feature axis to pick input and label columns.\n",
    "        sample_ds: Windowed dataset of a single group, for inspection.\n",
    "        training_ds, validation_ds, testing_ds: Windowed datasets, or None\n",
    "            when `generate_data_now` is False.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data, index: str = \"sleep_id\", input_width: int = TIME_STEP_INPUT, validation_size: int = VALIDATION_SIZE, test_size: int = TEST_SIZE, input_feature_slice: slice = slice(1,100), label_feature_slice: slice = slice(-4,100), generate_data_now: bool = True):\n",
    "        \"\"\"Partition `data` and set up the window geometry.\n",
    "\n",
    "        Args:\n",
    "            data: DataFrame holding every group's rows.\n",
    "            index: Column whose unique values identify the groups.\n",
    "            input_width: Number of consecutive rows in each input window.\n",
    "            validation_size: Number of groups held out for validation.\n",
    "            test_size: Number of groups held out for testing.\n",
    "            input_feature_slice: Columns used as model inputs (the default\n",
    "                skips the first column).\n",
    "            label_feature_slice: Columns used as labels (the default keeps\n",
    "                the last four columns).\n",
    "            generate_data_now: Build the three windowed datasets immediately.\n",
    "                When False they are left as None and can be built later with\n",
    "                `make_dataset`.\n",
    "        \"\"\"\n",
    "        self.index = index\n",
    "\n",
    "        # Partition data by whole groups. The `data` argument itself is split,\n",
    "        # not a notebook-level global, so the class works on any frame.\n",
    "        self.training, self.testing = training_test_split_by_unique_index(data, index, test_size)\n",
    "        self.training, self.validation = training_test_split_by_unique_index(self.training, index, validation_size)\n",
    "\n",
    "        # Window parameters\n",
    "        self.input_width = input_width\n",
    "        self.label_width = 1\n",
    "        self.shift = 1\n",
    "\n",
    "        self.total_window_size = self.input_width + self.shift\n",
    "\n",
    "        self.input_slice = slice(0, input_width)\n",
    "        self.input_indices = np.arange(self.total_window_size)[self.input_slice]\n",
    "\n",
    "        self.label_start = self.total_window_size - self.label_width\n",
    "        self.labels_slice = slice(self.label_start, None)\n",
    "        self.label_indices = np.arange(self.total_window_size)[self.labels_slice]\n",
    "\n",
    "        self.input_feature_slice = input_feature_slice\n",
    "        self.label_feature_slice = label_feature_slice\n",
    "\n",
    "        # One group's windows for quick inspection: group 0 when present,\n",
    "        # otherwise the group of the first row in `data`.\n",
    "        sample_rows = data[data[index] == 0]\n",
    "        if sample_rows.empty:\n",
    "            sample_rows = data[data[index] == data[index].iloc[0]]\n",
    "        self.sample_ds = self.make_dataset(sample_rows, index_group=index)\n",
    "\n",
    "        # Always define the dataset attributes so later access never raises\n",
    "        # AttributeError when generation was deferred.\n",
    "        self.training_ds = None\n",
    "        self.validation_ds = None\n",
    "        self.testing_ds = None\n",
    "        if generate_data_now:\n",
    "            self.training_ds = self.make_dataset(self.training, index_group=index)\n",
    "            self.validation_ds = self.make_dataset(self.validation, index_group=index)\n",
    "            self.testing_ds = self.make_dataset(self.testing, index_group=index)\n",
    "\n",
    "\n",
    "    def __repr__(self):\n",
    "        \"\"\"Summarise the window size and which positions are inputs vs. labels.\"\"\"\n",
    "        return \"WindowGenerator:\\n\\t\" +'\\n\\t'.join([\n",
    "            f'Total window size: {self.total_window_size}',\n",
    "            f'Input indices: {self.input_indices}',\n",
    "            f'Label indices: {self.label_indices}',\n",
    "        ])\n",
    "\n",
    "    def split_window(self, features):\n",
    "        \"\"\"Split a batch of full windows into `(inputs, labels)`.\n",
    "\n",
    "        Args:\n",
    "            features: 3-axis tensor indexed as (batch, window position, feature).\n",
    "\n",
    "        Returns:\n",
    "            Tuple `(inputs, labels)`: the first `input_width` positions with the\n",
    "            input columns, and the trailing `label_width` positions with the\n",
    "            label columns.\n",
    "        \"\"\"\n",
    "        inputs = features[:, self.input_slice, self.input_feature_slice]\n",
    "        labels = features[:, self.labels_slice, self.label_feature_slice]\n",
    "        # Slicing drops static shape information; restore the known time axis.\n",
    "        inputs.set_shape([None, self.input_width, None])\n",
    "        labels.set_shape([None, self.label_width, None])\n",
    "        return inputs, labels\n",
    "\n",
    "    def make_dataset(self, data, index_group: str = \"sleep_id\", sort_by: str = \"minutes_since_begin\"):\n",
    "        \"\"\"Build a batched dataset of `(inputs, labels)` windows from `data`.\n",
    "\n",
    "        Windows are generated per group and the per-group datasets are\n",
    "        concatenated, so no window spans two groups.\n",
    "\n",
    "        Args:\n",
    "            data: DataFrame with the rows to window.\n",
    "            index_group: Column identifying the groups.\n",
    "            sort_by: Column used to order rows inside each group.\n",
    "\n",
    "        Returns:\n",
    "            A `tf.data.Dataset` yielding `(inputs, labels)` batches of up to\n",
    "            `BATCH_SIZE` windows.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If `data` contains no groups.\n",
    "        \"\"\"\n",
    "        ds_all = None\n",
    "        # A single groupby pass replaces one full boolean-mask scan per group;\n",
    "        # sort=False keeps groups in order of first appearance.\n",
    "        for _, group_rows in data.groupby(index_group, sort=False):\n",
    "            subset_data = np.array(group_rows.sort_values(by=[sort_by]), dtype=np.float32)\n",
    "            ds = tf.keras.utils.timeseries_dataset_from_array(\n",
    "                data=subset_data,\n",
    "                targets=None,\n",
    "                sequence_length=self.total_window_size,\n",
    "                sequence_stride=1,\n",
    "                shuffle=False,\n",
    "                batch_size=BATCH_SIZE,)\n",
    "            ds_all = ds if ds_all is None else ds_all.concatenate(ds)\n",
    "\n",
    "        if ds_all is None:\n",
    "            raise ValueError(f\"No rows to window: column '{index_group}' has no groups.\")\n",
    "\n",
    "        return ds_all.map(self.split_window)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Prep\n",
    "\n",
    "All inputs follow: (batch_size, timesteps, input_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "WindowGenerator:\n",
       "\tTotal window size: 11\n",
       "\tInput indices: [0 1 2 3 4 5 6 7 8 9]\n",
       "\tLabel indices: [10]"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Constructing the generator with its defaults also performs the random\n",
    "# group split and builds the training / validation / testing datasets.\n",
    "wg = WindowGenerator(sleep_data)\n",
    "wg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the first batch of the single-group inspection dataset.\n",
    "sample = wg.sample_ds.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Materialise the batch as a list of (inputs, labels) NumPy tuples.\n",
    "sample_array = list(sample.as_numpy_iterator())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[18.,  8., 36.,  1.,  0.,  0.,  0.],\n",
       "        [19.,  8., 37.,  1.,  0.,  0.,  0.],\n",
       "        [20.,  8., 38.,  1.,  0.,  0.,  0.],\n",
       "        [21.,  8., 39.,  1.,  0.,  0.,  0.],\n",
       "        [22.,  8., 40.,  1.,  0.,  0.,  0.],\n",
       "        [23.,  8., 41.,  1.,  0.,  0.,  0.],\n",
       "        [24.,  8., 42.,  1.,  0.,  0.,  0.],\n",
       "        [25.,  8., 43.,  1.,  0.,  0.,  0.],\n",
       "        [26.,  8., 44.,  1.,  0.,  0.,  0.],\n",
       "        [27.,  8., 45.,  1.,  0.,  0.,  0.]], dtype=float32),\n",
       " array([[0., 0., 1., 0.]], dtype=float32))"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Despite its name, this selects a position along the batch axis: it picks\n",
    "# which window of the first batch to display.\n",
    "INDEX_TIMESTEP = 18\n",
    "# Show that window's inputs (element 0) next to its label (element 1).\n",
    "sample_array[0][0][INDEX_TIMESTEP], sample_array[0][1][INDEX_TIMESTEP]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model 1: LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 215,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters\n",
    "# Unit count passed to the LSTM layer.\n",
    "LSTM_UNITS = 16\n",
    "# Learning rate passed to the Adam optimizer.\n",
    "LSTM_LEARNING_RATE = 0.0001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model Definition\n",
    "# Derive the feature count from the window generator's input slice so the\n",
    "# input layer always matches what `wg` feeds the model (a hard-coded 8 did\n",
    "# not: the generator's inputs skip the first column).\n",
    "N_INPUT_FEATURES = len(sleep_data.columns[wg.input_feature_slice])\n",
    "\n",
    "lstm_model = keras.Sequential()\n",
    "lstm_model.add(layers.Input(shape=(TIME_STEP_INPUT, N_INPUT_FEATURES)))\n",
    "# Each window has a single label step, so only the final hidden state is used.\n",
    "# To stack LSTM layers, set return_sequences=True on every layer but the last.\n",
    "lstm_model.add(layers.LSTM(LSTM_UNITS, stateful=False, return_sequences=False))\n",
    "# No activation: the head emits raw logits, one per sleep stage.\n",
    "lstm_model.add(layers.Dense(SLEEP_STAGES))\n",
    "# Labels are shaped (batch, label_width, stages); mirror that shape.\n",
    "lstm_model.add(layers.Reshape((wg.label_width, SLEEP_STAGES)))\n",
    "# summary() prints the table itself and returns None, so do not wrap it in print().\n",
    "lstm_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model Training\n",
    "# The labels are per-stage probability vectors (one value per sleep stage),\n",
    "# not integer class ids, so the non-sparse categorical cross-entropy is needed.\n",
    "# from_logits=True because the Dense head has no softmax activation.\n",
    "lstm_loss = keras.losses.CategoricalCrossentropy(from_logits=True)\n",
    "lstm_optm = keras.optimizers.Adam(learning_rate=LSTM_LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model 2: GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters\n",
    "# Unit count passed to the GRU layer.\n",
    "GRU_UNITS = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model Definition\n",
    "# Feature count follows the window generator's input slice.\n",
    "N_INPUT_FEATURES = len(sleep_data.columns[wg.input_feature_slice])\n",
    "\n",
    "gru_model = keras.Sequential([\n",
    "    # An explicit input shape builds the model so summary() can run.\n",
    "    layers.Input(shape=(TIME_STEP_INPUT, N_INPUT_FEATURES)),\n",
    "    layers.GRU(GRU_UNITS),\n",
    "    # No activation: raw logits, one per sleep stage.\n",
    "    layers.Dense(SLEEP_STAGES),\n",
    "    # Labels are shaped (batch, label_width, stages); mirror that shape.\n",
    "    layers.Reshape((wg.label_width, SLEEP_STAGES)),\n",
    "])\n",
    "# The Embedding layer previously appended after the Dense head was removed:\n",
    "# it maps integer ids to vectors and cannot consume the head's float logits.\n",
    "gru_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model 3: Attention Mechanism"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters\n",
    "# Unit count passed to the Attention layer.\n",
    "ATTENTION_UNITS = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model Definition\n",
    "# Feature count follows the window generator's input slice.\n",
    "N_INPUT_FEATURES = len(sleep_data.columns[wg.input_feature_slice])\n",
    "\n",
    "am_model = keras.Sequential([\n",
    "    # An explicit input shape builds the model so summary() can run.\n",
    "    layers.Input(shape=(TIME_STEP_INPUT, N_INPUT_FEATURES)),\n",
    "    # NOTE(review): assumes Attention returns a 2D (batch, units) tensor --\n",
    "    # confirm against the installed `attention` package version.\n",
    "    Attention(ATTENTION_UNITS),\n",
    "    # No activation: raw logits, one per sleep stage.\n",
    "    layers.Dense(SLEEP_STAGES),\n",
    "    # Labels are shaped (batch, label_width, stages); mirror that shape.\n",
    "    layers.Reshape((wg.label_width, SLEEP_STAGES)),\n",
    "])\n",
    "am_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model Head-to-Head testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scratch"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "final_lab",
   "language": "python",
   "name": "final_lab"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}