{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Training and testing sleep-stage prediction models\n", "\n", "Cleans the raw bed sleep-state data, windows it per sleep session, and defines LSTM, GRU and attention models that predict the sleep stage of the next sample from the preceding `TIME_STEP_INPUT` samples." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import csv\n", "import datetime\n", "import itertools\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import tensorflow as tf\n", "from tensorflow import keras\n", "from tensorflow.keras import layers\n", "\n", "from attention import Attention" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.8.0\n" ] } ], "source": [ "print(tf.__version__)" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# CONSTANTS\n", "RAW_SLEEP_DATA_PATH = \".data/raw_bed_sleep-state.csv\"\n", "CLEANED_SLEEP_DATA_PATH = \".data/clean_bed_sleep-state.csv\"\n", "SEED = 42\n", "\n", "# Seed Python, NumPy and TensorFlow at once so the random session split and weight initialisation are reproducible\n", "tf.keras.utils.set_random_seed(SEED)" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "## Parameters and Hyper-parameters\n", "# Number of sleep stages; the cleaning step one-hot encodes stage_value 0=awake, 1=light, 2=deep, 3=rem\n", "SLEEP_STAGES = 4" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Import Data" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Cleaning Raw Data" ] },
{ "cell_type": "code", "execution_count": 104, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "datetime.datetime(2022, 4, 21, 10, 19, tzinfo=datetime.timezone(datetime.timedelta(seconds=7200)))" ] }, "execution_count": 104, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Sanity check of the timestamp format used when parsing the raw data\n", "datetime.datetime.strptime(\"2022-04-21T10:18:00+02:00\",\"%Y-%m-%dT%H:%M:%S%z\") + datetime.timedelta(minutes=1)" ] },
{ "cell_type": "code", "execution_count": 115, "metadata": {}, "outputs": [], "source": [ "def stage_probability(stage, to_test):\n", "    \"\"\"Return the one-hot component for `to_test`: 1.0 if `stage` equals it, else 0.0.\"\"\"\n", "    return 1.0 if stage == to_test else 0.0" ] }, { "cell_type": "code", 
"execution_count": 201, "metadata": {}, "outputs": [], "source": [
"def parse_int_list(field):\n",
"    \"\"\"Parse a bracketed, comma-separated field such as '[60,60,120]' into a list of ints.\"\"\"\n",
"    return [int(value) for value in field.strip(\"[]\").split(\",\")]\n",
"\n",
"\n",
"cleaned_info = [[\n",
"    \"sleep_id\",\n",
"    \"sleep_begin\",\n",
"    \"stage_start\",\n",
"    \"time_since_begin_sec\",\n",
"    \"stage_duration_sec\",\n",
"    \"stage_end\",\n",
"    \"stage_value\",\n",
"    \"awake_probability\",\n",
"    \"light_probability\",\n",
"    \"deep_probability\",\n",
"    \"rem_probability\",\n",
"]]\n",
"date_seen = set()\n",
"in_bed_time = None\n",
"current_sleep_id = -1\n",
"# The csv module requires files opened with newline=\"\" so it can handle line endings itself\n",
"with open(RAW_SLEEP_DATA_PATH, mode=\"r\", newline=\"\") as raw_file:\n",
"    csv_reader = csv.reader(raw_file)\n",
"    next(csv_reader, None)  # skip the raw file's header row\n",
"    for row in csv_reader:\n",
"        start_time = datetime.datetime.strptime(row[0], \"%Y-%m-%dT%H:%M:%S%z\")\n",
"        # Skip records whose start time was already processed (duplicates)\n",
"        if start_time in date_seen:\n",
"            continue\n",
"        date_seen.add(start_time)\n",
"        # A record that starts before the current session's start opens a new session\n",
"        if in_bed_time is None or in_bed_time > start_time:\n",
"            current_sleep_id += 1\n",
"            in_bed_time = start_time\n",
"        # Segment k starts after the summed durations of segments 0..k-1 (durations may differ);\n",
"        # stage values 0/1/2/3 are one-hot encoded as awake/light/deep/rem\n",
"        elapsed_sec = 0\n",
"        for duration, stage in zip(parse_int_list(row[1]), parse_int_list(row[2])):\n",
"            current_time = start_time + datetime.timedelta(seconds=elapsed_sec)\n",
"            cleaned_info.append([\n",
"                current_sleep_id,\n",
"                in_bed_time,\n",
"                current_time,\n",
"                # total_seconds() keeps whole days, unlike timedelta.seconds\n",
"                int((current_time - in_bed_time).total_seconds()),\n",
"                duration,\n",
"                current_time + datetime.timedelta(seconds=duration),\n",
"                stage,\n",
"                stage_probability(0, stage),\n",
"                stage_probability(1, stage),\n",
"                stage_probability(2, stage),\n",
"                stage_probability(3, stage),\n",
"            ])\n",
"            elapsed_sec += duration"
] },
{ "cell_type": "code", "execution_count": 202, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Finished Writing Cleaned Data\n" ] } ], "source": [
"# The csv module requires newline=\"\"; without it every row gets an extra carriage return on Windows\n",
"with open(CLEANED_SLEEP_DATA_PATH, mode=\"w\", newline=\"\") as clean_file:\n",
"    csv.writer(clean_file).writerows(cleaned_info)\n",
"print(\"Finished Writing Cleaned Data\")"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Creating a DataFrame from the cleaned data" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Get the cleaned data\n", "sleep_df_raw = pd.read_csv(CLEANED_SLEEP_DATA_PATH)" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
"# Preprocess data:\n",
"# 1. convert to datetime; utc=True normalises the timestamps' UTC offsets into one datetime64[ns, UTC] dtype\n",
"for datetime_column in [\"sleep_begin\", \"stage_start\", \"stage_end\"]:\n",
"    sleep_df_raw[datetime_column] = pd.to_datetime(sleep_df_raw[datetime_column], utc=True)\n",
"# 2. Separate time, hour and minute\n", "# MAYBE 3. 
smaller units: int16 or int8 " ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def get_minute(row, index):\n", "    \"\"\"Return the minute of the timestamp stored in `row[index]`.\"\"\"\n", "    return row[index].time().minute\n", "\n", "\n", "def get_hour(row, index):\n", "    \"\"\"Return the hour of the timestamp stored in `row[index]`.\"\"\"\n", "    return row[index].time().hour" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# The vectorised .dt accessors give the same values as applying get_hour/get_minute to every row\n", "# (stage_start is tz-aware UTC), without a Python-level loop over ~550k rows\n", "sleep_df_raw[\"stage_start_hour\"] = sleep_df_raw[\"stage_start\"].dt.hour\n", "sleep_df_raw[\"stage_start_minute\"] = sleep_df_raw[\"stage_start\"].dt.minute" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 'pandas.core.frame.DataFrame'>\n", "RangeIndex: 551042 entries, 0 to 551041\n", "Data columns (total 13 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 sleep_id 551042 non-null int64 \n", " 1 sleep_begin 551042 non-null datetime64[ns, UTC]\n", " 2 stage_start 551042 non-null datetime64[ns, UTC]\n", " 3 time_since_begin_sec 551042 non-null int64 \n", " 4 stage_duration_sec 551042 non-null int64 \n", " 5 stage_end 551042 non-null datetime64[ns, UTC]\n", " 6 stage_value 551042 non-null int64 \n", " 7 awake_probability 551042 non-null float64 \n", " 8 light_probability 551042 non-null float64 \n", " 9 deep_probability 551042 non-null float64 \n", " 10 rem_probability 551042 non-null float64 \n", " 11 stage_start_hour 551042 non-null int64 \n", " 12 stage_start_minute 551042 non-null int64 \n", "dtypes: datetime64[ns, UTC](3), float64(4), int64(6)\n", "memory usage: 54.7 MB\n" ] } ], "source": [ "sleep_df_raw.info()" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", 
" .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>sleep_id</th>\n", " <th>sleep_begin</th>\n", " <th>stage_start</th>\n", " <th>time_since_begin_sec</th>\n", " <th>stage_duration_sec</th>\n", " <th>stage_end</th>\n", " <th>stage_value</th>\n", " <th>awake_probability</th>\n", " <th>light_probability</th>\n", " <th>deep_probability</th>\n", " <th>rem_probability</th>\n", " <th>stage_start_hour</th>\n", " <th>stage_start_minute</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>2022-04-21 08:18:00+00:00</td>\n", " <td>2022-04-21 08:18:00+00:00</td>\n", " <td>0</td>\n", " <td>60</td>\n", " <td>2022-04-21 08:19:00+00:00</td>\n", " <td>0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>8</td>\n", " <td>18</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0</td>\n", " <td>2022-04-21 08:18:00+00:00</td>\n", " <td>2022-04-21 08:19:00+00:00</td>\n", " <td>60</td>\n", " <td>60</td>\n", " <td>2022-04-21 08:20:00+00:00</td>\n", " <td>0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>8</td>\n", " <td>19</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>0</td>\n", " <td>2022-04-21 08:18:00+00:00</td>\n", " <td>2022-04-21 08:20:00+00:00</td>\n", " <td>120</td>\n", " <td>60</td>\n", " <td>2022-04-21 08:21:00+00:00</td>\n", " <td>0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>8</td>\n", " <td>20</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>0</td>\n", " <td>2022-04-21 08:18:00+00:00</td>\n", " <td>2022-04-21 08:21:00+00:00</td>\n", " <td>180</td>\n", " <td>60</td>\n", " <td>2022-04-21 08:22:00+00:00</td>\n", " <td>0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>8</td>\n", " <td>21</td>\n", " 
</tr>\n", " <tr>\n", " <th>4</th>\n", " <td>0</td>\n", " <td>2022-04-21 08:18:00+00:00</td>\n", " <td>2022-04-21 08:22:00+00:00</td>\n", " <td>240</td>\n", " <td>60</td>\n", " <td>2022-04-21 08:23:00+00:00</td>\n", " <td>0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>8</td>\n", " <td>22</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>551037</th>\n", " <td>1132</td>\n", " <td>2019-02-11 06:11:00+00:00</td>\n", " <td>2019-02-11 13:17:00+00:00</td>\n", " <td>25560</td>\n", " <td>60</td>\n", " <td>2019-02-11 13:18:00+00:00</td>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>13</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>551038</th>\n", " <td>1132</td>\n", " <td>2019-02-11 06:11:00+00:00</td>\n", " <td>2019-02-11 13:18:00+00:00</td>\n", " <td>25620</td>\n", " <td>60</td>\n", " <td>2019-02-11 13:19:00+00:00</td>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>13</td>\n", " <td>18</td>\n", " </tr>\n", " <tr>\n", " <th>551039</th>\n", " <td>1132</td>\n", " <td>2019-02-11 06:11:00+00:00</td>\n", " <td>2019-02-11 13:19:00+00:00</td>\n", " <td>25680</td>\n", " <td>60</td>\n", " <td>2019-02-11 13:20:00+00:00</td>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>13</td>\n", " <td>19</td>\n", " </tr>\n", " <tr>\n", " <th>551040</th>\n", " <td>1132</td>\n", " <td>2019-02-11 06:11:00+00:00</td>\n", " <td>2019-02-11 13:20:00+00:00</td>\n", " <td>25740</td>\n", " <td>60</td>\n", " <td>2019-02-11 13:21:00+00:00</td>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " 
<td>0.0</td>\n", " <td>0.0</td>\n", " <td>13</td>\n", " <td>20</td>\n", " </tr>\n", " <tr>\n", " <th>551041</th>\n", " <td>1132</td>\n", " <td>2019-02-11 06:11:00+00:00</td>\n", " <td>2019-02-11 13:21:00+00:00</td>\n", " <td>25800</td>\n", " <td>60</td>\n", " <td>2019-02-11 13:22:00+00:00</td>\n", " <td>1</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>13</td>\n", " <td>21</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>551042 rows × 13 columns</p>\n", "</div>" ], "text/plain": [ " sleep_id sleep_begin stage_start \\\n", "0 0 2022-04-21 08:18:00+00:00 2022-04-21 08:18:00+00:00 \n", "1 0 2022-04-21 08:18:00+00:00 2022-04-21 08:19:00+00:00 \n", "2 0 2022-04-21 08:18:00+00:00 2022-04-21 08:20:00+00:00 \n", "3 0 2022-04-21 08:18:00+00:00 2022-04-21 08:21:00+00:00 \n", "4 0 2022-04-21 08:18:00+00:00 2022-04-21 08:22:00+00:00 \n", "... ... ... ... \n", "551037 1132 2019-02-11 06:11:00+00:00 2019-02-11 13:17:00+00:00 \n", "551038 1132 2019-02-11 06:11:00+00:00 2019-02-11 13:18:00+00:00 \n", "551039 1132 2019-02-11 06:11:00+00:00 2019-02-11 13:19:00+00:00 \n", "551040 1132 2019-02-11 06:11:00+00:00 2019-02-11 13:20:00+00:00 \n", "551041 1132 2019-02-11 06:11:00+00:00 2019-02-11 13:21:00+00:00 \n", "\n", " time_since_begin_sec stage_duration_sec stage_end \\\n", "0 0 60 2022-04-21 08:19:00+00:00 \n", "1 60 60 2022-04-21 08:20:00+00:00 \n", "2 120 60 2022-04-21 08:21:00+00:00 \n", "3 180 60 2022-04-21 08:22:00+00:00 \n", "4 240 60 2022-04-21 08:23:00+00:00 \n", "... ... ... ... \n", "551037 25560 60 2019-02-11 13:18:00+00:00 \n", "551038 25620 60 2019-02-11 13:19:00+00:00 \n", "551039 25680 60 2019-02-11 13:20:00+00:00 \n", "551040 25740 60 2019-02-11 13:21:00+00:00 \n", "551041 25800 60 2019-02-11 13:22:00+00:00 \n", "\n", " stage_value awake_probability light_probability deep_probability \\\n", "0 0 1.0 0.0 0.0 \n", "1 0 1.0 0.0 0.0 \n", "2 0 1.0 0.0 0.0 \n", "3 0 1.0 0.0 0.0 \n", "4 0 1.0 0.0 0.0 \n", "... ... ... ... ... 
\n", "551037 1 0.0 1.0 0.0 \n", "551038 1 0.0 1.0 0.0 \n", "551039 1 0.0 1.0 0.0 \n", "551040 1 0.0 1.0 0.0 \n", "551041 1 0.0 1.0 0.0 \n", "\n", " rem_probability stage_start_hour stage_start_minute \n", "0 0.0 8 18 \n", "1 0.0 8 19 \n", "2 0.0 8 20 \n", "3 0.0 8 21 \n", "4 0.0 8 22 \n", "... ... ... ... \n", "551037 0.0 13 17 \n", "551038 0.0 13 18 \n", "551039 0.0 13 19 \n", "551040 0.0 13 20 \n", "551041 0.0 13 21 \n", "\n", "[551042 rows x 13 columns]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sleep_df_raw" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Keep the model columns; .copy() makes sleep_data an independent frame before insert() adds a column\n", "# NOTE: the label columns are ordered awake, rem, light, deep, which differs from the stage_value\n", "# encoding (0=awake, 1=light, 2=deep, 3=rem), so a predicted class index is NOT a stage_value\n", "sleep_data = sleep_df_raw[[\"sleep_id\", \"stage_start_hour\", \"stage_start_minute\", \"awake_probability\", \"rem_probability\", \"light_probability\", \"deep_probability\"]].copy()\n", "sleep_data.insert(loc=1, column=\"minutes_since_begin\", value=sleep_df_raw[\"time_since_begin_sec\"] // 60)" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " sleep_id minutes_since_begin stage_start_hour stage_start_minute \\\n", "0 0 0 8 18 \n", "1 0 1 8 19 \n", "2 0 2 8 20 \n", "3 0 3 8 21 \n", "4 0 4 8 22 \n", "\n", " awake_probability rem_probability light_probability deep_probability \n", "0 1.0 0.0 0.0 0.0 \n", "1 1.0 0.0 0.0 0.0 \n", "2 1.0 0.0 0.0 0.0 \n", "3 1.0 0.0 0.0 0.0 \n", "4 1.0 0.0 0.0 0.0 \n", "<class 'pandas.core.frame.DataFrame'>\n", "RangeIndex: 551042 entries, 0 to 551041\n", "Data columns (total 8 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 sleep_id 551042 non-null int64 \n", " 1 minutes_since_begin 551042 non-null int64 \n", " 2 stage_start_hour 551042 non-null int64 \n", " 3 stage_start_minute 551042 non-null int64 \n", " 4 awake_probability 551042 non-null float64\n", " 5 rem_probability 551042 non-null float64\n", " 6 light_probability 551042 non-null float64\n", " 7
deep_probability 551042 non-null float64\n", "dtypes: float64(4), int64(4)\n", "memory usage: 33.6 MB\n" ] } ], "source": [ "print(sleep_data.head())\n", "# info() prints its report itself and returns None, so it must not be wrapped in print()\n", "sleep_data.info()" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>sleep_id</th>\n", " <th>minutes_since_begin</th>\n", " <th>stage_start_hour</th>\n", " <th>stage_start_minute</th>\n", " <th>awake_probability</th>\n", " <th>rem_probability</th>\n", " <th>light_probability</th>\n", " <th>deep_probability</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>8</td>\n", " <td>18</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>8</td>\n", " <td>19</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>0</td>\n", " <td>2</td>\n", " <td>8</td>\n", " <td>20</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>8</td>\n", " <td>21</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>0</td>\n", " <td>4</td>\n", " <td>8</td>\n", " <td>22</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", "
<td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>551037</th>\n", " <td>1132</td>\n", " <td>426</td>\n", " <td>13</td>\n", " <td>17</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>551038</th>\n", " <td>1132</td>\n", " <td>427</td>\n", " <td>13</td>\n", " <td>18</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>551039</th>\n", " <td>1132</td>\n", " <td>428</td>\n", " <td>13</td>\n", " <td>19</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>551040</th>\n", " <td>1132</td>\n", " <td>429</td>\n", " <td>13</td>\n", " <td>20</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>551041</th>\n", " <td>1132</td>\n", " <td>430</td>\n", " <td>13</td>\n", " <td>21</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>551042 rows × 8 columns</p>\n", "</div>" ], "text/plain": [ " sleep_id minutes_since_begin stage_start_hour stage_start_minute \\\n", "0 0 0 8 18 \n", "1 0 1 8 19 \n", "2 0 2 8 20 \n", "3 0 3 8 21 \n", "4 0 4 8 22 \n", "... ... ... ... ... \n", "551037 1132 426 13 17 \n", "551038 1132 427 13 18 \n", "551039 1132 428 13 19 \n", "551040 1132 429 13 20 \n", "551041 1132 430 13 21 \n", "\n", " awake_probability rem_probability light_probability \\\n", "0 1.0 0.0 0.0 \n", "1 1.0 0.0 0.0 \n", "2 1.0 0.0 0.0 \n", "3 1.0 0.0 0.0 \n", "4 1.0 0.0 0.0 \n", "... ... ... ... \n", "551037 0.0 0.0 1.0 \n", "551038 0.0 0.0 1.0 \n", "551039 0.0 0.0 1.0 \n", "551040 0.0 0.0 1.0 \n", "551041 0.0 0.0 1.0 \n", "\n", " deep_probability \n", "0 0.0 \n", "1 0.0 \n", "2 0.0 \n", "3 0.0 \n", "4 0.0 \n", "... ... 
\n", "551037 0.0 \n", "551038 0.0 \n", "551039 0.0 \n", "551040 0.0 \n", "551041 0.0 \n", "\n", "[551042 rows x 8 columns]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sleep_data" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Model Development" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Helper functions and class" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [
"TEST_SIZE = 122  # sleep sessions held out for testing\n",
"VALIDATION_SIZE = 183  # sleep sessions held out for validation\n",
"TIME_STEP_INPUT = 10  # input window length in rows (one row per minute in the data shown above)\n",
"BATCH_SIZE = 32\n",
"SPLIT_SEED = 42  # seed for the train/validation/test session split"
] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [
"def training_test_split_by_unique_index(data, index: str, test_size: int = 10, rng: np.random.Generator | None = None):\n",
"    \"\"\"Split `data` into (remaining, held_out) by drawing whole groups of `index`.\n",
"\n",
"    `test_size` distinct values of `data[index]` are drawn without replacement; every row\n",
"    carrying one of them goes to `held_out`, all other rows to `remaining`, so no group is\n",
"    split across both frames. Pass `rng` (a NumPy Generator) for a reproducible draw;\n",
"    otherwise the global `np.random` state is used.\n",
"    \"\"\"\n",
"    chooser = np.random if rng is None else rng\n",
"    test_ids = chooser.choice(data[index].unique(), size=test_size, replace=False)\n",
"    is_held_out = data[index].isin(test_ids)\n",
"    return data[~is_held_out], data[is_held_out]"
] },
{ "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [
"# Adapted from https://www.tensorflow.org/tutorials/structured_data/time_series\n",
"class WindowGenerator():\n",
"    \"\"\"Build windowed tf.data datasets of (inputs, labels) from per-session sleep data.\n",
"\n",
"    Sessions (unique values of `index`) are split into training/validation/testing sets and\n",
"    each session is windowed on its own, so no window mixes rows of two sessions. A window is\n",
"    `input_width` consecutive rows of inputs followed by the single row used as its label.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data, index: str = \"sleep_id\", input_width: int = TIME_STEP_INPUT, validation_size: int = VALIDATION_SIZE, test_size: int = TEST_SIZE, input_feature_slice: slice = slice(1,100), label_feature_slice: slice = slice(-4,100), generate_data_now: bool = True, seed: int | None = None):\n",
"        \"\"\"\n",
"        Args:\n",
"            data: DataFrame with an `index` column followed by numeric feature columns.\n",
"            index: column whose values identify a sleep session.\n",
"            input_width: number of consecutive rows given to the model.\n",
"            validation_size: number of sessions held out for validation.\n",
"            test_size: number of sessions held out for testing.\n",
"            input_feature_slice: positional column slice used as inputs.\n",
"            label_feature_slice: positional column slice used as labels.\n",
"            generate_data_now: build training/validation/testing datasets immediately.\n",
"            seed: optional seed for a reproducible session split (global NumPy RNG when None).\n",
"        \"\"\"\n",
"        self.index = index\n",
"\n",
"        # Partition whole sessions of `data` (not a global frame): test set first, then validation\n",
"        rng = None if seed is None else np.random.default_rng(seed)\n",
"        self.training, self.testing = training_test_split_by_unique_index(data, index, test_size, rng)\n",
"        self.training, self.validation = training_test_split_by_unique_index(self.training, index, validation_size, rng)\n",
"\n",
"        # Window parameters: the label is the single row right after the inputs\n",
"        self.input_width = input_width\n",
"        self.label_width = 1\n",
"        self.shift = 1\n",
"\n",
"        self.total_window_size = self.input_width + self.shift\n",
"\n",
"        self.input_slice = slice(0, input_width)\n",
"        self.input_indices = np.arange(self.total_window_size)[self.input_slice]\n",
"\n",
"        self.label_start = self.total_window_size - self.label_width\n",
"        self.labels_slice = slice(self.label_start, None)\n",
"        self.label_indices = np.arange(self.total_window_size)[self.labels_slice]\n",
"\n",
"        self.input_feature_slice = input_feature_slice\n",
"        self.label_feature_slice = label_feature_slice\n",
"\n",
"        # Small dataset from the first session in `data`, for inspecting windows\n",
"        first_session = data[index].iloc[0]\n",
"        self.sample_ds = self.make_dataset(data[data[index] == first_session], index_group=index)\n",
"\n",
"        self.training_ds = None\n",
"        self.validation_ds = None\n",
"        self.testing_ds = None\n",
"        if generate_data_now:\n",
"            self.generate_all_datasets()\n",
"\n",
"    def __repr__(self):\n",
"        return \"WindowGenerator:\\n\\t\" + '\\n\\t'.join([\n",
"            f'Total window size: {self.total_window_size}',\n",
"            f'Input indices: {self.input_indices}',\n",
"            f'Label indices: {self.label_indices}',\n",
"        ])\n",
"\n",
"    def split_window(self, features):\n",
"        \"\"\"Split a batch of windows (batch, total_window_size, columns) into (inputs, labels).\"\"\"\n",
"        inputs = features[:, self.input_slice, self.input_feature_slice]\n",
"        labels = features[:, self.labels_slice, self.label_feature_slice]\n",
"        # Pin the static time dimension so downstream Keras layers see fixed-length windows\n",
"        inputs.set_shape([None, self.input_width, None])\n",
"        labels.set_shape([None, self.label_width, None])\n",
"        return inputs, labels\n",
"\n",
"    def make_dataset(self, data, index_group: str = \"sleep_id\", sort_by: str = \"minutes_since_begin\"):\n",
"        \"\"\"Window every `index_group` group of `data` separately and concatenate the results.\n",
"\n",
"        Each group is sorted by `sort_by`, windowed with stride 1 and batched with BATCH_SIZE,\n",
"        so windows and batches never cross group boundaries. Returns a dataset of\n",
"        (inputs, labels) batches produced by `split_window`.\n",
"\n",
"        Raises:\n",
"            ValueError: if `data` has no rows.\n",
"        \"\"\"\n",
"        ds_all = None\n",
"        # One groupby pass instead of re-filtering the whole frame for every group;\n",
"        # sort=False keeps first-appearance order, the same order as .unique()\n",
"        for _, group_data in data.groupby(index_group, sort=False):\n",
"            subset_data = np.array(group_data.sort_values(by=[sort_by]), dtype=np.float32)\n",
"            ds = tf.keras.utils.timeseries_dataset_from_array(\n",
"                data=subset_data,\n",
"                targets=None,\n",
"                sequence_length=self.total_window_size,\n",
"                sequence_stride=1,\n",
"                shuffle=False,\n",
"                batch_size=BATCH_SIZE,\n",
"            )\n",
"            ds_all = ds if ds_all is None else ds_all.concatenate(ds)\n",
"        if ds_all is None:\n",
"            raise ValueError(\"make_dataset received no rows to window\")\n",
"        return ds_all.map(self.split_window)\n",
"\n",
"    def generate_all_datasets(self):\n",
"        \"\"\"(Re)build training_ds, validation_ds and testing_ds from the stored splits.\"\"\"\n",
"        self.training_ds = self.make_dataset(self.training, index_group=self.index)\n",
"        self.validation_ds = self.make_dataset(self.validation, index_group=self.index)\n",
"        self.testing_ds = self.make_dataset(self.testing, index_group=self.index)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Data Prep\n", "\n", "All inputs follow: (batch_size, timesteps, input_dim); labels follow: (batch_size, 1, SLEEP_STAGES)" ] },
{ "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "WindowGenerator:\n", "\tTotal window size: 11\n", "\tInput indices: [0 1 2 3 4 5 6 7 8 9]\n", "\tLabel indices: [10]" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wg = WindowGenerator(sleep_data, seed=SPLIT_SEED)\n", "wg" ] },
{ "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "sample = wg.sample_ds.take(1)" ] },
{ "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "sample_array = list(sample.as_numpy_iterator())" ] },
{ "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([[18., 8., 36., 1., 0., 0., 0.],\n", " [19., 8., 37., 1., 0., 0., 0.],\n", " [20., 8., 38., 1., 0., 0., 0.],\n", " [21., 8., 39., 1., 0., 0., 0.],\n", " [22., 8., 40., 1., 0., 0., 0.],\n", " [23., 8., 41., 1., 0., 0., 0.],\n", " [24., 8., 42., 1., 0., 0., 0.],\n", " [25., 8., 43., 1., 0., 0., 0.],\n", " [26., 8., 44., 1., 0., 0., 0.],\n", " [27., 8., 45., 1., 0., 0., 0.]], dtype=float32),\n", " array([[0., 0., 1., 0.]], dtype=float32))" ] }, "execution_count": 37, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "# Index of one window inside the first sample batch\n", "INDEX_TIMESTEP = 18\n", "sample_array[0][0][INDEX_TIMESTEP], sample_array[0][1][INDEX_TIMESTEP]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Model 1: LSTM" ] },
{ "cell_type": "code", "execution_count": 215, "metadata": {}, "outputs": [], "source": [ "# Hyper-parameters\n", "LSTM_UNITS = 16\n", "LSTM_LEARNING_RATE = 0.0001" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Model Definition\n",
"# Input shape comes from the windowed data: (TIME_STEP_INPUT, number of input features)\n",
"lstm_input_shape = tuple(wg.sample_ds.element_spec[0].shape[1:])\n",
"lstm_model = keras.Sequential([\n",
"    layers.Input(shape=lstm_input_shape),\n",
"    # Only the final hidden state is needed to predict the single label step\n",
"    layers.LSTM(LSTM_UNITS, stateful=False, return_sequences=False),\n",
"    layers.Dense(SLEEP_STAGES),\n",
"    # Match the label shape produced by WindowGenerator: (batch, 1, SLEEP_STAGES)\n",
"    layers.Reshape((1, SLEEP_STAGES)),\n",
"])\n",
"lstm_model.summary()"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Model Training\n",
"# Labels are one-hot vectors, so use CategoricalCrossentropy (the sparse variant expects integer class ids);\n",
"# the last Dense layer has no activation, hence from_logits=True\n",
"lstm_loss = keras.losses.CategoricalCrossentropy(from_logits=True)\n",
"lstm_optm = keras.optimizers.Adam(learning_rate=LSTM_LEARNING_RATE)\n",
"lstm_model.compile(loss=lstm_loss, optimizer=lstm_optm, metrics=[keras.metrics.CategoricalAccuracy()])"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Model 2: GRU" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hyper-parameters\n", "GRU_UNITS = 16" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Input shape comes from the windowed data: (TIME_STEP_INPUT, number of input features)\n",
"gru_input_shape = tuple(wg.sample_ds.element_spec[0].shape[1:])\n",
"gru_model = keras.Sequential([\n",
"    layers.Input(shape=gru_input_shape),\n",
"    layers.GRU(GRU_UNITS),\n",
"    layers.Dense(SLEEP_STAGES),\n",
"    # Match the label shape produced by WindowGenerator: (batch, 1, SLEEP_STAGES)\n",
"    layers.Reshape((1, SLEEP_STAGES)),\n",
"])\n",
"gru_model.summary()"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Model 3: Attention Mechanism" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ATTENTION_UNITS = 16" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Input shape comes from the windowed data: (TIME_STEP_INPUT, number of input features)\n",
"am_input_shape = tuple(wg.sample_ds.element_spec[0].shape[1:])\n",
"am_model = keras.Sequential([\n",
"    layers.Input(shape=am_input_shape),\n",
"    # NOTE(review): assumes attention.Attention reduces (batch, time, features) to (batch, units) -- confirm\n",
"    Attention(ATTENTION_UNITS),\n",
"    layers.Dense(SLEEP_STAGES),\n",
"    # Match the label shape produced by WindowGenerator: (batch, 1, SLEEP_STAGES)\n",
"    layers.Reshape((1, SLEEP_STAGES)),\n",
"])\n",
"am_model.summary()"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model Head-to-Head testing" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Scratch" ] } ], "metadata": { "kernelspec": { "display_name": "final_lab", "language": "python", "name": "final_lab" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }