Skip to content
Snippets Groups Projects
tf_model.ipynb 180 KiB
Newer Older
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook for training and testing sleep-stage AI models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import pandas as pd\n",
SurajSSingh's avatar
SurajSSingh committed
    "from attention import CustomAttention\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.8.0\n"
     ]
    }
   ],
   "source": [
    "print(tf.version.VERSION)"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONSTANTS: all data files live under one directory, so moving the data only needs DATA_DIR changed\n",
    "DATA_DIR = \".data\"\n",
    "RAW_SLEEP_DATA_PATH = f\"{DATA_DIR}/raw_bed_sleep-state.csv\"\n",
    "CLEANED_SLEEP_DATA_PATH = f\"{DATA_DIR}/clean_bed_sleep-state.csv\"\n",
    "SLEEP_DATA_PATH = f\"{DATA_DIR}/sleep_data_simple.csv\"\n",
    "UPDATED_SLEEP_DATA_PATH = f\"{DATA_DIR}/updated_sleep_data.csv\""
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters and hyper-parameters\n",
    "# Sleep stages encoded as stage values 0-3: awake, light, deep, REM\n",
    "SLEEP_STAGES = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cleaning Raw Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import datetime\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "datetime.datetime(2022, 4, 21, 10, 19, tzinfo=datetime.timezone(datetime.timedelta(seconds=7200)))"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: the raw timestamp format (with UTC offset) parses, and adding one minute behaves as expected\n",
    "parsed_stamp = datetime.datetime.strptime(\"2022-04-21T10:18:00+02:00\", \"%Y-%m-%dT%H:%M:%S%z\")\n",
    "parsed_stamp + datetime.timedelta(seconds=60)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stage_probability(stage, to_test):\n",
    "    \"\"\"One-hot encode a stage: 1.0 when `to_test` equals `stage`, otherwise 0.0.\"\"\"\n",
    "    return float(stage == to_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6.833333333333333"
      ]
     },
     "execution_count": 150,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hours between the last parsed record's start and the current session's begin time.\n",
    "# NOTE: uses `start_time` / `in_bed_time` defined by the raw-data cleaning cell below; run that cell first.\n",
    "# total_seconds() is used because timedelta.seconds silently drops whole days.\n",
    "(start_time - in_bed_time).total_seconds() / 3600"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten every raw record into one row per recorded sleep stage.\n",
    "# Each raw row holds a start timestamp plus two bracketed lists: stage durations (seconds) and stage values.\n",
    "def parse_int_list(field):\n",
    "    \"\"\"Parse a bracketed, comma-separated field such as [60,60,120] into a list of ints.\"\"\"\n",
    "    return [int(item) for item in field.strip(\"[]\").split(\",\")]\n",
    "\n",
    "\n",
    "cleaned_info = [[\n",
    "    \"sleep_id\",\n",
    "    \"sleep_begin\",\n",
    "    \"stage_start\",\n",
    "    \"time_since_begin_sec\",\n",
    "    \"stage_duration_sec\",\n",
    "    \"stage_end\",\n",
    "    \"stage_value\",\n",
    "    \"awake_probability\",\n",
    "    \"light_probability\",\n",
    "    \"deep_probability\",\n",
    "    \"rem_probability\",\n",
    "]]\n",
    "date_seen = set()\n",
    "# newline=\"\" is what the csv module expects for files it reads.\n",
    "with open(RAW_SLEEP_DATA_PATH, mode=\"r\", newline=\"\") as raw_file:\n",
    "    csv_reader = csv.reader(raw_file)\n",
    "    next(csv_reader, None)  # the raw file's own header row is replaced by the header above\n",
    "    in_bed_time = None\n",
    "    current_sleep_id = -1\n",
    "    for lines in csv_reader:\n",
    "        start_time = datetime.datetime.strptime(lines[0], \"%Y-%m-%dT%H:%M:%S%z\")\n",
    "        # Skip records whose start time was already processed (duplicates).\n",
    "        if start_time in date_seen:\n",
    "            continue\n",
    "        date_seen.add(start_time)\n",
    "        # A record starting before the current session's begin time opens a new session id.\n",
    "        # NOTE(review): this depends on the row order of the raw file - confirm it is sorted as expected.\n",
    "        if in_bed_time is None or in_bed_time > start_time:\n",
    "            current_sleep_id += 1\n",
    "            in_bed_time = start_time\n",
    "        # Each stage starts where the previous one ended: at the running sum of the preceding\n",
    "        # durations, so stages of unequal length are placed correctly.\n",
    "        elapsed_sec = 0\n",
    "        for duration, stage in zip(parse_int_list(lines[1]), parse_int_list(lines[2])):\n",
    "            current_time = start_time + datetime.timedelta(seconds=elapsed_sec)\n",
    "            cleaned_info.append([\n",
    "                current_sleep_id,\n",
    "                in_bed_time,\n",
    "                current_time,\n",
    "                # total_seconds() keeps whole days that timedelta.seconds would drop\n",
    "                int((current_time - in_bed_time).total_seconds()),\n",
    "                duration,\n",
    "                current_time + datetime.timedelta(seconds=duration),\n",
    "                stage,\n",
    "                stage_probability(0, stage),\n",
    "                stage_probability(1, stage),\n",
    "                stage_probability(2, stage),\n",
    "                stage_probability(3, stage),\n",
    "            ])\n",
    "            elapsed_sec += duration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished Writing Cleaned Data\n"
     ]
    }
   ],
   "source": [
    "# newline=\"\" is required by the csv module; without it Windows gets a blank line between rows.\n",
    "with open(CLEANED_SLEEP_DATA_PATH, \"w\", newline=\"\") as clean_file:\n",
    "    csv.writer(clean_file).writerows(cleaned_info)\n",
    "print(\"Finished Writing Cleaned Data\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a DataFrame from the cleaned raw data"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-stage rows written above; timestamp columns are parsed in the next cell.\n",
    "sleep_df_raw = pd.read_csv(CLEANED_SLEEP_DATA_PATH)"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess: parse the timestamp columns as timezone-aware UTC datetimes.\n",
    "# TODO: separate time into hour and minute; maybe downcast to smaller ints (int16 / int8).\n",
    "for datetime_column in (\"sleep_begin\", \"stage_start\", \"stage_end\"):\n",
    "    sleep_df_raw[datetime_column] = pd.to_datetime(sleep_df_raw[datetime_column], utc=True)"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_minute(row, index):\n",
    "    \"\"\"Return the minute of the timestamp stored under `index` in `row`.\"\"\"\n",
    "    timestamp = row[index]\n",
    "    return timestamp.time().minute\n",
    "\n",
    "def get_hour(row, index):\n",
    "    \"\"\"Return the hour of the timestamp stored under `index` in `row`.\"\"\"\n",
    "    timestamp = row[index]\n",
    "    return timestamp.time().hour"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vectorised .dt accessors replace a Python-level row-wise apply over ~550k rows;\n",
    "# astype(\"int64\") keeps the same dtype the apply version produced.\n",
    "sleep_df_raw[\"stage_start_hour\"] = sleep_df_raw[\"stage_start\"].dt.hour.astype(\"int64\")\n",
    "sleep_df_raw[\"stage_start_minute\"] = sleep_df_raw[\"stage_start\"].dt.minute.astype(\"int64\")"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 83,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 551042 entries, 0 to 551041\n",
      "Data columns (total 13 columns):\n",
      " #   Column                Non-Null Count   Dtype              \n",
      "---  ------                --------------   -----              \n",
      " 0   sleep_id              551042 non-null  int64              \n",
      " 1   sleep_begin           551042 non-null  datetime64[ns, UTC]\n",
      " 2   stage_start           551042 non-null  datetime64[ns, UTC]\n",
      " 3   time_since_begin_sec  551042 non-null  int64              \n",
      " 4   stage_duration_sec    551042 non-null  int64              \n",
      " 5   stage_end             551042 non-null  datetime64[ns, UTC]\n",
      " 6   stage_value           551042 non-null  int64              \n",
      " 7   awake_probability     551042 non-null  float64            \n",
      " 8   light_probability     551042 non-null  float64            \n",
      " 9   deep_probability      551042 non-null  float64            \n",
      " 10  rem_probability       551042 non-null  float64            \n",
      " 11  stage_start_hour      551042 non-null  int64              \n",
      " 12  stage_start_minute    551042 non-null  int64              \n",
      "dtypes: datetime64[ns, UTC](3), float64(4), int64(6)\n",
      "memory usage: 54.7 MB\n"
     ]
    }
   ],
   "source": [
    "# Column dtypes and non-null counts after preprocessing\n",
    "sleep_df_raw.info()"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sleep_id</th>\n",
       "      <th>sleep_begin</th>\n",
       "      <th>stage_start</th>\n",
       "      <th>time_since_begin_sec</th>\n",
       "      <th>stage_duration_sec</th>\n",
       "      <th>stage_end</th>\n",
       "      <th>stage_value</th>\n",
       "      <th>awake_probability</th>\n",
       "      <th>light_probability</th>\n",
       "      <th>deep_probability</th>\n",
       "      <th>rem_probability</th>\n",
       "      <th>stage_start_hour</th>\n",
       "      <th>stage_start_minute</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:19:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:19:00+00:00</td>\n",
       "      <td>60</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:20:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:20:00+00:00</td>\n",
       "      <td>120</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:21:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:21:00+00:00</td>\n",
       "      <td>180</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:22:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:22:00+00:00</td>\n",
       "      <td>240</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:23:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>22</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551037</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:17:00+00:00</td>\n",
       "      <td>25560</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:18:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551038</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:18:00+00:00</td>\n",
       "      <td>25620</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:19:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551039</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:19:00+00:00</td>\n",
       "      <td>25680</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:20:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551040</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:20:00+00:00</td>\n",
       "      <td>25740</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:21:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551041</th>\n",
       "      <td>1132</td>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:21:00+00:00</td>\n",
       "      <td>25800</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:22:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>13</td>\n",
       "      <td>21</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>551042 rows × 13 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        sleep_id               sleep_begin               stage_start  \\\n",
       "0              0 2022-04-21 08:18:00+00:00 2022-04-21 08:18:00+00:00   \n",
       "1              0 2022-04-21 08:18:00+00:00 2022-04-21 08:19:00+00:00   \n",
       "2              0 2022-04-21 08:18:00+00:00 2022-04-21 08:20:00+00:00   \n",
       "3              0 2022-04-21 08:18:00+00:00 2022-04-21 08:21:00+00:00   \n",
       "4              0 2022-04-21 08:18:00+00:00 2022-04-21 08:22:00+00:00   \n",
       "...          ...                       ...                       ...   \n",
       "551037      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:17:00+00:00   \n",
       "551038      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:18:00+00:00   \n",
       "551039      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:19:00+00:00   \n",
       "551040      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:20:00+00:00   \n",
       "551041      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:21:00+00:00   \n",
       "\n",
       "        time_since_begin_sec  stage_duration_sec                 stage_end  \\\n",
       "0                          0                  60 2022-04-21 08:19:00+00:00   \n",
       "1                         60                  60 2022-04-21 08:20:00+00:00   \n",
       "2                        120                  60 2022-04-21 08:21:00+00:00   \n",
       "3                        180                  60 2022-04-21 08:22:00+00:00   \n",
       "4                        240                  60 2022-04-21 08:23:00+00:00   \n",
       "...                      ...                 ...                       ...   \n",
       "551037                 25560                  60 2019-02-11 13:18:00+00:00   \n",
       "551038                 25620                  60 2019-02-11 13:19:00+00:00   \n",
       "551039                 25680                  60 2019-02-11 13:20:00+00:00   \n",
       "551040                 25740                  60 2019-02-11 13:21:00+00:00   \n",
       "551041                 25800                  60 2019-02-11 13:22:00+00:00   \n",
       "\n",
       "        stage_value  awake_probability  light_probability  deep_probability  \\\n",
       "0                 0                1.0                0.0               0.0   \n",
       "1                 0                1.0                0.0               0.0   \n",
       "2                 0                1.0                0.0               0.0   \n",
       "3                 0                1.0                0.0               0.0   \n",
       "4                 0                1.0                0.0               0.0   \n",
       "...             ...                ...                ...               ...   \n",
       "551037            1                0.0                1.0               0.0   \n",
       "551038            1                0.0                1.0               0.0   \n",
       "551039            1                0.0                1.0               0.0   \n",
       "551040            1                0.0                1.0               0.0   \n",
       "551041            1                0.0                1.0               0.0   \n",
       "\n",
       "        rem_probability  stage_start_hour  stage_start_minute  \n",
       "0                   0.0                 8                  18  \n",
       "1                   0.0                 8                  19  \n",
       "2                   0.0                 8                  20  \n",
       "3                   0.0                 8                  21  \n",
       "4                   0.0                 8                  22  \n",
       "...                 ...               ...                 ...  \n",
       "551037              0.0                13                  17  \n",
       "551038              0.0                13                  18  \n",
       "551039              0.0                13                  19  \n",
       "551040              0.0                13                  20  \n",
       "551041              0.0                13                  21  \n",
       "\n",
       "[551042 rows x 13 columns]"
SurajSSingh's avatar
SurajSSingh committed
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview of the preprocessed frame (pandas shows only the first and last rows)\n",
    "sleep_df_raw"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model inputs. The probability column order (awake, rem, light, deep) matches the A R L D order of CONFUSION_MATRIX.\n",
    "# .copy() makes sleep_data an explicit independent frame, so the insert below is unambiguously on a copy.\n",
    "sleep_data = sleep_df_raw[[\"sleep_id\", \"stage_start_hour\", \"stage_start_minute\", \"awake_probability\", \"rem_probability\", \"light_probability\", \"deep_probability\"]].copy()\n",
    "sleep_data.insert(loc=1, column=\"minutes_since_begin\", value=sleep_df_raw[\"time_since_begin_sec\"] // 60)"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   sleep_id  minutes_since_begin  stage_start_hour  stage_start_minute  \\\n",
      "0         0                    0                 8                  18   \n",
      "1         0                    1                 8                  19   \n",
      "2         0                    2                 8                  20   \n",
      "3         0                    3                 8                  21   \n",
      "4         0                    4                 8                  22   \n",
      "\n",
      "   awake_probability  rem_probability  light_probability  deep_probability  \n",
      "0                1.0              0.0                0.0               0.0  \n",
      "1                1.0              0.0                0.0               0.0  \n",
      "2                1.0              0.0                0.0               0.0  \n",
      "3                1.0              0.0                0.0               0.0  \n",
      "4                1.0              0.0                0.0               0.0  \n",
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 551042 entries, 0 to 551041\n",
      "Data columns (total 8 columns):\n",
      " #   Column               Non-Null Count   Dtype  \n",
      "---  ------               --------------   -----  \n",
      " 0   sleep_id             551042 non-null  int64  \n",
      " 1   minutes_since_begin  551042 non-null  int64  \n",
      " 2   stage_start_hour     551042 non-null  int64  \n",
      " 3   stage_start_minute   551042 non-null  int64  \n",
      " 4   awake_probability    551042 non-null  float64\n",
      " 5   rem_probability      551042 non-null  float64\n",
      " 6   light_probability    551042 non-null  float64\n",
      " 7   deep_probability     551042 non-null  float64\n",
      "dtypes: float64(4), int64(4)\n",
      "memory usage: 33.6 MB\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "display(sleep_data.head())\n",
    "sleep_data.info()  # info() prints its own report and returns None, so it must not be wrapped in print()"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the configured constant so this is the same file that SLEEP_DATA_PATH reads back.\n",
    "sleep_data.to_csv(SLEEP_DATA_PATH, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Development"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 5,
SurajSSingh's avatar
SurajSSingh committed
   "metadata": {},
   "outputs": [],
   "source": [
SurajSSingh's avatar
SurajSSingh committed
    "# Data split sizes\n",
    "# NOTE(review): units are not stated here - presumably counted in sleep sessions; confirm.\n",
    "TEST_SIZE = 365 // 2\n",
    "VALIDATION_SIZE = 365\n",
    "\n",
    "# Training configuration\n",
    "BATCH_SIZE = 64\n",
    "INPUT_TIME_STEP = 5  # in minutes\n",
    "# presumably the non-id columns of sleep_data (minutes_since_begin, hour, minute, 4 stage probabilities) - verify\n",
    "INPUT_FEATURES_SIZE = 7\n",
    "MAX_EPOCHS = 20"
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 6,
SurajSSingh's avatar
SurajSSingh committed
   "metadata": {},
   "outputs": [],
   "source": [
SurajSSingh's avatar
SurajSSingh committed
    "SAMPLE_COUNT = 1000\n",
    "\n",
    "# Stage confusion matrix given in percent and converted to fractions.\n",
    "# Rows and columns are both ordered Awake, REM, Light, Deep (A R L D).\n",
    "# NOTE(review): which axis is true vs. predicted is not stated - confirm. The Awake row sums to 99.9%, not 100%.\n",
    "CONFUSION_MATRIX = np.array([\n",
    "    #  A     R     L     D\n",
    "    [66.2,  5.0, 22.5,  6.2],  # A\n",
    "    [ 1.6, 60.7, 33.0,  4.7],  # R\n",
    "    [ 3.8, 22.3, 55.4, 18.5],  # L\n",
    "    [ 0.0,  1.3, 26.7, 72.0],  # D\n",
    "]) / 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import Data"
SurajSSingh's avatar
SurajSSingh committed
   ]
  },
  {
   "cell_type": "code",
SurajSSingh's avatar
SurajSSingh committed
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The updated file holds *_noisy and *_original probability columns;\n",
    "# SLEEP_DATA_PATH is the plain per-minute export written earlier, usable as an alternative input.\n",
    "sleep_data = pd.read_csv(UPDATED_SLEEP_DATA_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sleep_id</th>\n",
       "      <th>minutes_since_begin</th>\n",
       "      <th>stage_start_hour</th>\n",
       "      <th>stage_start_minute</th>\n",
SurajSSingh's avatar
SurajSSingh committed
       "      <th>awake_probability_noisy</th>\n",
       "      <th>rem_probability_noisy</th>\n",
       "      <th>light_probability_noisy</th>\n",
       "      <th>deep_probability_noisy</th>\n",
       "      <th>awake_probability_original</th>\n",
       "      <th>rem_probability_original</th>\n",
       "      <th>light_probability_original</th>\n",
       "      <th>deep_probability_original</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
SurajSSingh's avatar
SurajSSingh committed
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>0.680</td>\n",
       "      <td>0.057</td>\n",
       "      <td>0.200</td>\n",
       "      <td>0.063</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
SurajSSingh's avatar
SurajSSingh committed
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>0.652</td>\n",
       "      <td>0.060</td>\n",
       "      <td>0.224</td>\n",
       "      <td>0.064</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
SurajSSingh's avatar
SurajSSingh committed
       "      <td>0.0</td>\n",
       "      <td>2.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>0.672</td>\n",
       "      <td>0.059</td>\n",
       "      <td>0.209</td>\n",
       "      <td>0.060</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
SurajSSingh's avatar
SurajSSingh committed
       "      <td>0.0</td>\n",
       "      <td>3.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>0.645</td>\n",
       "      <td>0.056</td>\n",
       "      <td>0.235</td>\n",
       "      <td>0.064</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
SurajSSingh's avatar
SurajSSingh committed
       "      <td>0.0</td>\n",
       "      <td>4.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>0.644</td>\n",
       "      <td>0.054</td>\n",
       "      <td>0.244</td>\n",
       "      <td>0.058</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
SurajSSingh's avatar
SurajSSingh committed
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551037</th>\n",
SurajSSingh's avatar
SurajSSingh committed
       "      <td>1132.0</td>\n",
       "      <td>426.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>17.0</td>\n",
       "      <td>0.041</td>\n",
       "      <td>0.193</td>\n",
       "      <td>0.576</td>\n",
       "      <td>0.190</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551038</th>\n",
SurajSSingh's avatar
SurajSSingh committed
       "      <td>1132.0</td>\n",
       "      <td>427.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>18.0</td>\n",
       "      <td>0.027</td>\n",
       "      <td>0.209</td>\n",
       "      <td>0.563</td>\n",
       "      <td>0.201</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551039</th>\n",
SurajSSingh's avatar
SurajSSingh committed
       "      <td>1132.0</td>\n",
       "      <td>428.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>19.0</td>\n",
       "      <td>0.032</td>\n",
       "      <td>0.220</td>\n",
       "      <td>0.574</td>\n",
       "      <td>0.174</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551040</th>\n",
SurajSSingh's avatar
SurajSSingh committed
       "      <td>1132.0</td>\n",
       "      <td>429.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>20.0</td>\n",
       "      <td>0.036</td>\n",
       "      <td>0.256</td>\n",
       "      <td>0.530</td>\n",
       "      <td>0.178</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551041</th>\n",
SurajSSingh's avatar
SurajSSingh committed
       "      <td>1132.0</td>\n",
       "      <td>430.0</td>\n",
       "      <td>13.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>0.033</td>\n",
       "      <td>0.205</td>\n",
       "      <td>0.571</td>\n",
       "      <td>0.191</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
SurajSSingh's avatar
SurajSSingh committed
       "<p>551042 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        sleep_id  minutes_since_begin  stage_start_hour  stage_start_minute  \\\n",
SurajSSingh's avatar
SurajSSingh committed
       "0            0.0                  0.0               8.0                18.0   \n",
       "1            0.0                  1.0               8.0                19.0   \n",
       "2            0.0                  2.0               8.0                20.0   \n",
       "3            0.0                  3.0               8.0                21.0   \n",
       "4            0.0                  4.0               8.0                22.0   \n",
       "...          ...                  ...               ...                 ...   \n",
SurajSSingh's avatar
SurajSSingh committed
       "551037    1132.0                426.0              13.0                17.0   \n",
       "551038    1132.0                427.0              13.0                18.0   \n",
       "551039    1132.0                428.0              13.0                19.0   \n",
       "551040    1132.0                429.0              13.0                20.0   \n",
       "551041    1132.0                430.0              13.0                21.0   \n",
       "\n",
       "        awake_probability_noisy  rem_probability_noisy  \\\n",
       "0                         0.680                  0.057   \n",
       "1                         0.652                  0.060   \n",
       "2                         0.672                  0.059   \n",
       "3                         0.645                  0.056   \n",
       "4                         0.644                  0.054   \n",
       "...                         ...                    ...   \n",
       "551037                    0.041                  0.193   \n",
       "551038                    0.027                  0.209   \n",
       "551039                    0.032                  0.220   \n",
       "551040                    0.036                  0.256   \n",
       "551041                    0.033                  0.205   \n",
       "\n",
       "        light_probability_noisy  deep_probability_noisy  \\\n",
       "0                         0.200                   0.063   \n",
       "1                         0.224                   0.064   \n",
       "2                         0.209                   0.060   \n",
       "3                         0.235                   0.064   \n",
       "4                         0.244                   0.058   \n",
       "...                         ...                     ...   \n",
       "551037                    0.576                   0.190   \n",
       "551038                    0.563                   0.201   \n",
       "551039                    0.574                   0.174   \n",
       "551040                    0.530                   0.178   \n",
       "551041                    0.571                   0.191   \n",