{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
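    "# Sleep-Stage Models: Training and Testing\n",
    "\n",
    "Cleans the raw bed sleep-state export into one row per sleep-stage interval (awake, light, deep, REM), loads it into a pandas DataFrame, and sketches a first GRU model for sleep-stage prediction."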
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.8.0\n"
     ]
    }
   ],
   "source": [
    "# Record the TensorFlow version so results can be matched to it\n",
    "print(tf.version.VERSION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CONSTANTS\n",
    "# Both data files live in the same relative directory\n",
    "DATA_DIR = \".data\"\n",
    "RAW_SLEEP_DATA_PATH = f\"{DATA_DIR}/raw_bed_sleep-state.csv\"\n",
    "CLEANED_SLEEP_DATA_PATH = f\"{DATA_DIR}/clean_bed_sleep-state.csv\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters and Hyper-parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stage values 0-3, in the same order as the *_probability columns built during cleaning\n",
    "SLEEP_STAGE_NAMES = (\"awake\", \"light\", \"deep\", \"rem\")\n",
    "SLEEP_STAGES = len(SLEEP_STAGE_NAMES)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cleaning Raw Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import datetime\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "datetime.datetime(2022, 4, 21, 10, 19, tzinfo=datetime.timezone(datetime.timedelta(seconds=7200)))"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity-check the raw timestamp format and timedelta arithmetic\n",
    "sample_start = datetime.datetime.strptime(\"2022-04-21T10:18:00+02:00\", \"%Y-%m-%dT%H:%M:%S%z\")\n",
    "sample_start + datetime.timedelta(minutes=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stage_probability(stage, to_test):\n",
    "    \"\"\"Return a one-hot probability for a sleep stage.\n",
    "\n",
    "    Gives 1.0 when ``to_test`` equals ``stage`` and 0.0 otherwise.\n",
    "    \"\"\"\n",
    "    if stage == to_test:\n",
    "        return 1.0\n",
    "    return 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch: timedelta.seconds is only the seconds component (whole days are dropped);\n",
    "# total_seconds() gives the full length of the interval.\n",
    "example_gap = datetime.timedelta(days=1, hours=6, minutes=50)\n",
    "example_gap.seconds / 3600, example_gap.total_seconds() / 3600"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten the raw export into one row per sleep-stage interval.\n",
    "# Each raw record (after the header row) holds a start timestamp, a bracketed\n",
    "# list of interval durations in seconds and a parallel list of stage values.\n",
    "CLEANED_HEADER = [\n",
    "    \"sleep_begin\",\n",
    "    \"stage_start\",\n",
    "    \"time_since_begin_sec\",\n",
    "    \"stage_duration_sec\",\n",
    "    \"stage_end\",\n",
    "    \"stage_value\",\n",
    "    \"awake_probability\",\n",
    "    \"light_probability\",\n",
    "    \"deep_probability\",\n",
    "    \"rem_probability\",\n",
    "]\n",
    "\n",
    "\n",
    "def parse_int_list(raw):\n",
    "    \"\"\"Parse a bracketed, comma-separated list such as \"[60,60,120]\" into ints.\n",
    "\n",
    "    An empty list (\"[]\") yields [] instead of failing on int(\"\").\n",
    "    \"\"\"\n",
    "    inner = raw.strip().strip(\"[]\")\n",
    "    return [int(item) for item in inner.split(\",\") if item.strip()]\n",
    "\n",
    "\n",
    "cleaned_info = [CLEANED_HEADER]\n",
    "date_seen = set()\n",
    "# newline=\"\" is what the csv module documentation asks for when reading files\n",
    "with open(RAW_SLEEP_DATA_PATH, mode=\"r\", newline=\"\") as raw_file:\n",
    "    csv_reader = csv.reader(raw_file)\n",
    "    next(csv_reader, None)  # skip the raw header row\n",
    "    for lines in csv_reader:\n",
    "        if not lines:  # csv.reader yields [] for blank lines\n",
    "            continue\n",
    "        start_time = datetime.datetime.strptime(lines[0], \"%Y-%m-%dT%H:%M:%S%z\")\n",
    "        # Skip records whose start timestamp has already been processed (duplicates).\n",
    "        if start_time in date_seen:\n",
    "            continue\n",
    "        date_seen.add(start_time)\n",
    "        # Intervals are treated as back-to-back (stage_end = start + duration), so each\n",
    "        # interval's offset is the running sum of the durations before it. Multiplying\n",
    "        # the index by one duration is only right when every duration is identical.\n",
    "        elapsed_sec = 0\n",
    "        for duration, stage in zip(parse_int_list(lines[1]), parse_int_list(lines[2])):\n",
    "            current_time = start_time + datetime.timedelta(seconds=elapsed_sec)\n",
    "            cleaned_info.append([\n",
    "                start_time,  # sleep_begin: start of the record this interval belongs to\n",
    "                current_time,\n",
    "                elapsed_sec,  # exact seconds since sleep_begin (no wrap-around at 24 h)\n",
    "                duration,\n",
    "                current_time + datetime.timedelta(seconds=duration),\n",
    "                stage,\n",
    "                stage_probability(0, stage),\n",
    "                stage_probability(1, stage),\n",
    "                stage_probability(2, stage),\n",
    "                stage_probability(3, stage),\n",
    "            ])\n",
    "            elapsed_sec += duration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The csv module documentation asks for newline=\"\" so the writer controls line\n",
    "# endings itself (otherwise CRLF platforms get an extra carriage return per row).\n",
    "with open(CLEANED_SLEEP_DATA_PATH, mode=\"w\", newline=\"\") as clean_file:\n",
    "    writer = csv.writer(clean_file)\n",
    "    writer.writerows(cleaned_info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a DataFrame from the cleaned data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cleaned one-row-per-interval table written above.\n",
    "# Timestamp columns are converted in the next cell (pd.to_datetime with utc=True).\n",
    "sleep_df_raw = pd.read_csv(CLEANED_SLEEP_DATA_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess data:\n",
    "#   1. convert the timestamp columns to tz-aware datetimes normalised to UTC\n",
    "#      (utc=True because the stored timestamps may carry different UTC offsets)\n",
    "for timestamp_column in (\"sleep_begin\", \"stage_start\", \"stage_end\"):\n",
    "    sleep_df_raw[timestamp_column] = pd.to_datetime(sleep_df_raw[timestamp_column], utc=True)\n",
    "#   TODO 2. consider smaller integer dtypes (int16 or int8) to save memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 551042 entries, 0 to 551041\n",
      "Data columns (total 10 columns):\n",
      " #   Column                Non-Null Count   Dtype              \n",
      "---  ------                --------------   -----              \n",
      " 0   sleep_begin           551042 non-null  datetime64[ns, UTC]\n",
      " 1   stage_start           551042 non-null  datetime64[ns, UTC]\n",
      " 2   time_since_begin_sec  551042 non-null  int64              \n",
      " 3   stage_duration_sec    551042 non-null  int64              \n",
      " 4   stage_end             551042 non-null  datetime64[ns, UTC]\n",
      " 5   stage_value           551042 non-null  int64              \n",
      " 6   awake_probability     551042 non-null  float64            \n",
      " 7   light_probability     551042 non-null  float64            \n",
      " 8   deep_probability      551042 non-null  float64            \n",
      " 9   rem_probability       551042 non-null  float64            \n",
      "dtypes: datetime64[ns, UTC](3), float64(4), int64(3)\n",
      "memory usage: 42.0 MB\n"
     ]
    }
   ],
   "source": [
    "# Column dtypes, non-null counts and memory footprint after the datetime conversion\n",
    "sleep_df_raw.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sleep_begin</th>\n",
       "      <th>stage_start</th>\n",
       "      <th>time_since_begin_sec</th>\n",
       "      <th>stage_duration_sec</th>\n",
       "      <th>stage_end</th>\n",
       "      <th>stage_value</th>\n",
       "      <th>awake_probability</th>\n",
       "      <th>light_probability</th>\n",
       "      <th>deep_probability</th>\n",
       "      <th>rem_probability</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:19:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:19:00+00:00</td>\n",
       "      <td>60</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:20:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:20:00+00:00</td>\n",
       "      <td>120</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:21:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:21:00+00:00</td>\n",
       "      <td>180</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:22:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2022-04-21 08:18:00+00:00</td>\n",
       "      <td>2022-04-21 08:22:00+00:00</td>\n",
       "      <td>240</td>\n",
       "      <td>60</td>\n",
       "      <td>2022-04-21 08:23:00+00:00</td>\n",
       "      <td>0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551037</th>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:17:00+00:00</td>\n",
       "      <td>25560</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:18:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551038</th>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:18:00+00:00</td>\n",
       "      <td>25620</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:19:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551039</th>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:19:00+00:00</td>\n",
       "      <td>25680</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:20:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551040</th>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:20:00+00:00</td>\n",
       "      <td>25740</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:21:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>551041</th>\n",
       "      <td>2019-02-11 06:11:00+00:00</td>\n",
       "      <td>2019-02-11 13:21:00+00:00</td>\n",
       "      <td>25800</td>\n",
       "      <td>60</td>\n",
       "      <td>2019-02-11 13:22:00+00:00</td>\n",
       "      <td>1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>551042 rows × 10 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                     sleep_begin               stage_start  \\\n",
       "0      2022-04-21 08:18:00+00:00 2022-04-21 08:18:00+00:00   \n",
       "1      2022-04-21 08:18:00+00:00 2022-04-21 08:19:00+00:00   \n",
       "2      2022-04-21 08:18:00+00:00 2022-04-21 08:20:00+00:00   \n",
       "3      2022-04-21 08:18:00+00:00 2022-04-21 08:21:00+00:00   \n",
       "4      2022-04-21 08:18:00+00:00 2022-04-21 08:22:00+00:00   \n",
       "...                          ...                       ...   \n",
       "551037 2019-02-11 06:11:00+00:00 2019-02-11 13:17:00+00:00   \n",
       "551038 2019-02-11 06:11:00+00:00 2019-02-11 13:18:00+00:00   \n",
       "551039 2019-02-11 06:11:00+00:00 2019-02-11 13:19:00+00:00   \n",
       "551040 2019-02-11 06:11:00+00:00 2019-02-11 13:20:00+00:00   \n",
       "551041 2019-02-11 06:11:00+00:00 2019-02-11 13:21:00+00:00   \n",
       "\n",
       "        time_since_begin_sec  stage_duration_sec                 stage_end  \\\n",
       "0                          0                  60 2022-04-21 08:19:00+00:00   \n",
       "1                         60                  60 2022-04-21 08:20:00+00:00   \n",
       "2                        120                  60 2022-04-21 08:21:00+00:00   \n",
       "3                        180                  60 2022-04-21 08:22:00+00:00   \n",
       "4                        240                  60 2022-04-21 08:23:00+00:00   \n",
       "...                      ...                 ...                       ...   \n",
       "551037                 25560                  60 2019-02-11 13:18:00+00:00   \n",
       "551038                 25620                  60 2019-02-11 13:19:00+00:00   \n",
       "551039                 25680                  60 2019-02-11 13:20:00+00:00   \n",
       "551040                 25740                  60 2019-02-11 13:21:00+00:00   \n",
       "551041                 25800                  60 2019-02-11 13:22:00+00:00   \n",
       "\n",
       "        stage_value  awake_probability  light_probability  deep_probability  \\\n",
       "0                 0                1.0                0.0               0.0   \n",
       "1                 0                1.0                0.0               0.0   \n",
       "2                 0                1.0                0.0               0.0   \n",
       "3                 0                1.0                0.0               0.0   \n",
       "4                 0                1.0                0.0               0.0   \n",
       "...             ...                ...                ...               ...   \n",
       "551037            1                0.0                1.0               0.0   \n",
       "551038            1                0.0                1.0               0.0   \n",
       "551039            1                0.0                1.0               0.0   \n",
       "551040            1                0.0                1.0               0.0   \n",
       "551041            1                0.0                1.0               0.0   \n",
       "\n",
       "        rem_probability  \n",
       "0                   0.0  \n",
       "1                   0.0  \n",
       "2                   0.0  \n",
       "3                   0.0  \n",
       "4                   0.0  \n",
       "...                 ...  \n",
       "551037              0.0  \n",
       "551038              0.0  \n",
       "551039              0.0  \n",
       "551040              0.0  \n",
       "551041              0.0  \n",
       "\n",
       "[551042 rows x 10 columns]"
      ]
     },
     "execution_count": 162,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview of the cleaned frame (pandas truncates long frames to the first and last rows)\n",
    "sleep_df_raw"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# How often each stage value occurs in the cleaned data\n",
    "sleep_df_raw[\"stage_value\"].value_counts().sort_index()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model 1: GRU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First GRU sketch: one output unit per sleep stage (linear outputs, i.e. logits).\n",
    "# NOTE(review): Embedding input_dim must exceed the largest input token id; 1000 looks\n",
    "# like a placeholder - confirm it once the model's input sequence is defined.\n",
    "model = keras.Sequential()\n",
    "model.add(layers.Embedding(input_dim=1000, output_dim=64))\n",
    "model.add(layers.GRU(128))\n",
    "model.add(layers.Dense(SLEEP_STAGES))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scratch"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "final_lab",
   "language": "python",
   "name": "final_lab"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}