 "cells": [
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training and testing sleep-stage prediction models"
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import pandas as pd"
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
     "name": "stdout",
     "output_type": "stream",
     "text": [
   "source": [
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# File locations, relative to the notebook's working directory.\n",
    "# NOTE(review): `.data` is a hidden directory - confirm it is not meant to be `./data`.\n",
    "RAW_SLEEP_DATA_PATH = \".data/raw_bed_sleep-state.csv\"\n",
    "CLEANED_SLEEP_DATA_PATH = \".data/clean_bed_sleep-state.csv\""
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters and Hyperparameters"
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of sleep-stage classes: 0=awake, 1=light, 2=deep, 3=REM (one *_probability column each)\n",
    "SLEEP_STAGES = 4\n"
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Data"
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cleaning Raw Data"
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import datetime\n",
    "import itertools"
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
     "data": {
      "text/plain": [
       "datetime.datetime(2022, 4, 21, 10, 19, tzinfo=datetime.timezone(datetime.timedelta(seconds=7200)))"
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
   "source": [
    "# Sanity check: the raw timestamp format (including the %z UTC offset) parses and supports arithmetic\n",
    "datetime.datetime.strptime(\n",
    "    \"2022-04-21T10:18:00+02:00\",\n",
    "    \"%Y-%m-%dT%H:%M:%S%z\",\n",
    ") + datetime.timedelta(seconds=60)"
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stage_probability(stage, to_test):\n",
    "    \"\"\"One-hot helper: return 1.0 when `to_test` equals `stage`, otherwise 0.0.\"\"\"\n",
    "    return float(stage == to_test)"
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [
     "data": {
      "text/plain": [
     "execution_count": 150,
     "metadata": {},
     "output_type": "execute_result"
   "source": [
    "# Debug check: hours between `in_bed_time` and `start_time` as left by the cleaning loop below.\n",
    "# Only works after that cell has run; total_seconds() keeps whole days that `.seconds` drops.\n",
    "(start_time - in_bed_time).total_seconds() / 3600"
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten the raw export into one record per sleep-stage segment.\n",
    "# Each raw row holds: row[0] = session start (\"%Y-%m-%dT%H:%M:%S%z\"),\n",
    "# row[1] = \"[d1,d2,...]\" segment durations in seconds, row[2] = \"[s1,s2,...]\" stage values.\n",
    "def parse_int_list(field):\n",
    "    \"\"\"Parse a bracketed list field like [60,60,120] into ints ([] gives an empty list).\"\"\"\n",
    "    return [int(item) for item in field.strip(\"[]\").split(\",\") if item.strip()]\n",
    "\n",
    "\n",
    "cleaned_info = [[\n",
    "    \"sleep_begin\",\n",
    "    \"stage_start\",\n",
    "    \"time_since_begin_sec\",\n",
    "    \"stage_duration_sec\",\n",
    "    \"stage_end\",\n",
    "    \"stage_value\",\n",
    "    \"awake_probability\",\n",
    "    \"light_probability\",\n",
    "    \"deep_probability\",\n",
    "    \"rem_probability\",\n",
    "]]\n",
    "seen_session_starts = set()\n",
    "with open(RAW_SLEEP_DATA_PATH, mode=\"r\", newline=\"\") as raw_file:\n",
    "    reader = csv.reader(raw_file)\n",
    "    next(reader, None)  # skip the raw header row (row 0 is never parsed as data)\n",
    "    for row in reader:\n",
    "        start_time = datetime.datetime.strptime(row[0], \"%Y-%m-%dT%H:%M:%S%z\")\n",
    "        if start_time in seen_session_starts:\n",
    "            continue  # skip rows whose start was already seen (presumably duplicate exports)\n",
    "        seen_session_starts.add(start_time)\n",
    "        # Every segment in this row belongs to the session starting at start_time,\n",
    "        # independent of the order in which rows appear in the raw file.\n",
    "        in_bed_time = start_time\n",
    "        durations = parse_int_list(row[1])\n",
    "        stages = parse_int_list(row[2])\n",
    "        # A segment starts where the previous one ended: offsets are the running sum of\n",
    "        # earlier durations, not index * a single (previous) duration.\n",
    "        offsets = itertools.accumulate(durations, initial=0)\n",
    "        # NOTE(review): zip stops at the shorter list if durations and stages differ in length.\n",
    "        for offset_sec, duration, stage in zip(offsets, durations, stages):\n",
    "            current_time = in_bed_time + datetime.timedelta(seconds=offset_sec)\n",
    "            cleaned_info.append([\n",
    "                in_bed_time,\n",
    "                current_time,\n",
    "                offset_sec,\n",
    "                duration,\n",
    "                current_time + datetime.timedelta(seconds=duration),\n",
    "                stage,\n",
    "                stage_probability(0, stage),\n",
    "                stage_probability(1, stage),\n",
    "                stage_probability(2, stage),\n",
    "                stage_probability(3, stage),\n",
    "            ])"
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The csv module requires newline=\"\"; without it rows get doubled line endings on Windows.\n",
    "with open(CLEANED_SLEEP_DATA_PATH, \"w\", newline=\"\") as clean_file:\n",
    "    csv.writer(clean_file).writerows(cleaned_info)"
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a DataFrame from the cleaned data"
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cleaned per-segment table; timestamp columns are converted in the next cell\n",
    "sleep_df_raw = pd.read_csv(CLEANED_SLEEP_DATA_PATH)"
   "cell_type": "code",
   "execution_count": 160,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess data:\n",
    "#   1. parse the timestamp columns (stored as text with UTC offsets) into tz-aware UTC datetimes\n",
    "for column in (\"sleep_begin\", \"stage_start\", \"stage_end\"):\n",
    "    sleep_df_raw[column] = pd.to_datetime(sleep_df_raw[column], utc=True)\n",
    "#   TODO 2. downcast numeric columns to smaller dtypes (int16 / int8)"
   "cell_type": "code",
   "execution_count": 161,
   "metadata": {},
   "outputs": [
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 551042 entries, 0 to 551041\n",
      "Data columns (total 10 columns):\n",
      " #   Column                Non-Null Count   Dtype              \n",
      "---  ------                --------------   -----              \n",
      " 0   sleep_begin           551042 non-null  datetime64[ns, UTC]\n",
      " 1   stage_start           551042 non-null  datetime64[ns, UTC]\n",
      " 2   time_since_begin_sec  551042 non-null  int64              \n",
      " 3   stage_duration_sec    551042 non-null  int64              \n",
      " 4   stage_end             551042 non-null  datetime64[ns, UTC]\n",
      " 5   stage_value           551042 non-null  int64              \n",
      " 6   awake_probability     551042 non-null  float64            \n",
      " 7   light_probability     551042 non-null  float64            \n",
      " 8   deep_probability      551042 non-null  float64            \n",
      " 9   rem_probability       551042 non-null  float64            \n",
      "dtypes: datetime64[ns, UTC](3), float64(4), int64(3)\n",
      "memory usage: 42.0 MB\n"
   "source": [
   "cell_type": "code",
   "execution_count": 162,
   "metadata": {},
   "outputs": [
     "data": {
      "text/html": [
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sleep_begin</th>\n",
       "      <th>stage_start</th>\n",
       "      <th>time_since_begin_sec</th>\n",
       "      <th>stage_duration_sec</th>\n",
       "      <th>stage_end</th>\n",
       "      <th>stage_value</th>\n",
       "      <th>awake_probability</th>\n",
       "      <th>light_probability</th>\n",
       "      <th>deep_probability</th>\n",
       "      <th>rem_probability</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:19:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:19:00+00:00</td>\n",
       "      <td>60</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:20:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:20:00+00:00</td>\n",
       "      <td>120</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:21:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:21:00+00:00</td>\n",
       "      <td>180</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:22:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:22:00+00:00</td>\n",
       "      <td>240</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:23:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551037</th>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:17:00+00:00</td>\n",
       "      <td>25560</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:18:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551038</th>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:18:00+00:00</td>\n",
       "      <td>25620</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:19:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551039</th>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:19:00+00:00</td>\n",
       "      <td>25680</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:20:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551040</th>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:20:00+00:00</td>\n",
       "      <td>25740</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:21:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551041</th>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:21:00+00:00</td>\n",
       "      <td>25800</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:22:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "<p>551042 rows × 10 columns</p>\n",
      "text/plain": [
       "                     sleep_begin               stage_start  \\\n",
       "0      2022-04-21 08:18:00+00:00 2022-04-21 08:18:00+00:00   \n",
       "1      2022-04-21 08:18:00+00:00 2022-04-21 08:19:00+00:00   \n",
       "2      2022-04-21 08:18:00+00:00 2022-04-21 08:20:00+00:00   \n",
       "3      2022-04-21 08:18:00+00:00 2022-04-21 08:21:00+00:00   \n",
       "4      2022-04-21 08:18:00+00:00 2022-04-21 08:22:00+00:00   \n",
       "...                          ...                       ...   \n",
       "551037 2019-02-11 06:11:00+00:00 2019-02-11 13:17:00+00:00   \n",
       "551038 2019-02-11 06:11:00+00:00 2019-02-11 13:18:00+00:00   \n",
       "551039 2019-02-11 06:11:00+00:00 2019-02-11 13:19:00+00:00   \n",
       "551040 2019-02-11 06:11:00+00:00 2019-02-11 13:20:00+00:00   \n",
       "551041 2019-02-11 06:11:00+00:00 2019-02-11 13:21:00+00:00   \n",
       "        time_since_begin_sec  stage_duration_sec                 stage_end  \\\n",
       "0                          0                  60 2022-04-21 08:19:00+00:00   \n",
       "1                         60                  60 2022-04-21 08:20:00+00:00   \n",
       "2                        120                  60 2022-04-21 08:21:00+00:00   \n",
       "3                        180                  60 2022-04-21 08:22:00+00:00   \n",
       "4                        240                  60 2022-04-21 08:23:00+00:00   \n",
       "...                      ...                 ...                       ...   \n",
       "551037                 25560                  60 2019-02-11 13:18:00+00:00   \n",
       "551038                 25620                  60 2019-02-11 13:19:00+00:00   \n",
       "551039                 25680                  60 2019-02-11 13:20:00+00:00   \n",
       "551040                 25740                  60 2019-02-11 13:21:00+00:00   \n",
       "551041                 25800                  60 2019-02-11 13:22:00+00:00   \n",
       "        stage_value  awake_probability  light_probability  deep_probability  \\\n",
       "0                 0                1.0                0.0               0.0   \n",
       "1                 0                1.0                0.0               0.0   \n",
       "2                 0                1.0                0.0               0.0   \n",
       "3                 0                1.0                0.0               0.0   \n",
       "4                 0                1.0                0.0               0.0   \n",
       "...             ...                ...                ...               ...   \n",
       "551037            1                0.0                1.0               0.0   \n",
       "551038            1                0.0                1.0               0.0   \n",
       "551039            1                0.0                1.0               0.0   \n",
       "551040            1                0.0                1.0               0.0   \n",
       "551041            1                0.0                1.0               0.0   \n",
       "        rem_probability  \n",
       "0                   0.0  \n",
       "1                   0.0  \n",
       "2                   0.0  \n",
       "3                   0.0  \n",
       "4                   0.0  \n",
       "...                 ...  \n",
       "551037              0.0  \n",
       "551038              0.0  \n",
       "551039              0.0  \n",
       "551040              0.0  \n",
       "551041              0.0  \n",
       "[551042 rows x 10 columns]"
     "execution_count": 162,
     "metadata": {},
     "output_type": "execute_result"
   "source": [
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [
     "data": {
      "text/plain": [
       "start                2019-02-11 06:11:00+00:00\n",
       "duration                                    60\n",
       "end                  2019-02-11 06:12:00+00:00\n",
       "stage                                        0\n",
       "awake_probability                          0.0\n",
       "light_probability                          0.0\n",
       "deep_probability                           0.0\n",
       "rem_probability                            0.0\n",
       "dtype: object"
     "execution_count": 143,
     "metadata": {},
     "output_type": "execute_result"
   "source": []
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model 1: GRU"
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = keras.Sequential()\n",
    "model.add(layers.Embedding(input_dim=1000, output_dim=64))\n",
   "cell_type": "markdown",
   "metadata": {},
   "source": []
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scratch"
 "metadata": {
  "kernelspec": {
   "display_name": "final_lab",
   "language": "python",
   "name": "final_lab"
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  "orig_nbformat": 4
 "nbformat": 4,
 "nbformat_minor": 2