diff --git a/attention.py b/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..69a94dc5f9e7b31209dc252e9912404ea178c1a0
--- /dev/null
+++ b/attention.py
@@ -0,0 +1,80 @@
+# Provided by under the Apache License, Version 2.0 (https://www.apache.org/licenses/LICENSE-2.0)
+# Source: https://github.com/philipperemy/keras-attention-mechanism
+import os
+
+from tensorflow.keras import backend as K
+from tensorflow.keras.layers import Dense, Lambda, Dot, Activation, Concatenate, Layer
+
+# KERAS_ATTENTION_DEBUG: If set to 1. Will switch to debug mode.
+# In debug mode, the class Attention is no longer a Keras layer.
+# What it means in practice is that we can have access to the internal values
+# of each tensor. If we don't use debug, Keras treats the object
+# as a layer and we can only get the final output.
+debug_flag = int(os.environ.get('KERAS_ATTENTION_DEBUG', 0))
+
+
+class Attention(object if debug_flag else Layer):
+
+    def __init__(self, units=128, **kwargs):
+        super(Attention, self).__init__(**kwargs)
+        self.units = units
+
+    # noinspection PyAttributeOutsideInit
+    def build(self, input_shape):
+        input_dim = int(input_shape[-1])
+        with K.name_scope(self.name if not debug_flag else 'attention'):
+            self.attention_score_vec = Dense(input_dim, use_bias=False, name='attention_score_vec')
+            self.h_t = Lambda(lambda x: x[:, -1, :], output_shape=(input_dim,), name='last_hidden_state')
+            self.attention_score = Dot(axes=[1, 2], name='attention_score')
+            self.attention_weight = Activation('softmax', name='attention_weight')
+            self.context_vector = Dot(axes=[1, 1], name='context_vector')
+            self.attention_output = Concatenate(name='attention_output')
+            self.attention_vector = Dense(self.units, use_bias=False, activation='tanh', name='attention_vector')
+        if not debug_flag:
+            # debug: the call to build() is done in call().
+            super(Attention, self).build(input_shape)
+
+    def compute_output_shape(self, input_shape):
+        return input_shape[0], self.units
+
+    def __call__(self, inputs, training=None, **kwargs):
+        if debug_flag:
+            return self.call(inputs, training, **kwargs)
+        else:
+            return super(Attention, self).__call__(inputs, training, **kwargs)
+
+    # noinspection PyUnusedLocal
+    def call(self, inputs, training=None, **kwargs):
+        """
+        Many-to-one attention mechanism for Keras.
+        @param inputs: 3D tensor with shape (batch_size, time_steps, input_dim).
+        @param training: not used in this layer.
+        @return: 2D tensor with shape (batch_size, units)
+        @author: felixhao28, philipperemy.
+        """
+        if debug_flag:
+            self.build(inputs.shape)
+        # Inside dense layer
+        #              hidden_states            dot               W            =>           score_first_part
+        # (batch_size, time_steps, hidden_size) dot (hidden_size, hidden_size) => (batch_size, time_steps, hidden_size)
+        # W is the trainable weight matrix of attention Luong's multiplicative style score
+        score_first_part = self.attention_score_vec(inputs)
+        #            score_first_part           dot        last_hidden_state     => attention_weights
+        # (batch_size, time_steps, hidden_size) dot   (batch_size, hidden_size)  => (batch_size, time_steps)
+        h_t = self.h_t(inputs)
+        score = self.attention_score([h_t, score_first_part])
+        attention_weights = self.attention_weight(score)
+        # (batch_size, time_steps, hidden_size) dot (batch_size, time_steps) => (batch_size, hidden_size)
+        context_vector = self.context_vector([inputs, attention_weights])
+        pre_activation = self.attention_output([context_vector, h_t])
+        attention_vector = self.attention_vector(pre_activation)
+        return attention_vector
+
+    def get_config(self):
+        """
+        Returns the config of a the layer. This is used for saving and loading from a model
+        :return: python dictionary with specs to rebuild layer
+        """
+        config = super(Attention, self).get_config()
+        config.update({'units': self.units})
+        return config
\ No newline at end of file
diff --git a/tf_model.ipynb b/tf_model.ipynb
index 259961689cdafba423c4f4cf1f2bd237855cab12..efcdfa79fba233262de74809a897831726664598 100644
--- a/tf_model.ipynb
+++ b/tf_model.ipynb
@@ -14,9 +14,11 @@
    "outputs": [],
    "source": [
     "import tensorflow as tf\n",
+    "import numpy as np\n",
     "from tensorflow import keras\n",
     "from tensorflow.keras import layers\n",
-    "import pandas as pd"
+    "import pandas as pd\n",
+    "from attention import Attention"
    ]
   },
   {
@@ -38,7 +40,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -47,20 +49,14 @@
     "CLEANED_SLEEP_DATA_PATH = \".data/clean_bed_sleep-state.csv\""
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Parameters and Hyper-parameters"
-   ]
-  },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
-    "SLEEP_STAGES = 4\n"
+    "## Parameters and Hyper-parameters\n",
+    "SLEEP_STAGES = 4"
    ]
   },
   {
@@ -140,7 +136,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 156,
+   "execution_count": 201,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -149,12 +145,14 @@
     "previous_duration = 60\n",
     "with open(RAW_SLEEP_DATA_PATH, mode ='r') as raw_file:\n",
     "    csvFile = csv.reader(raw_file)\n",
-    "    max_count = 1\n",
-    "    stuff = set()\n",
+    "    # max_count = 1\n",
+    "    # stuff = set()\n",
     "    in_bed_time = None\n",
+    "    current_sleep_id = -1\n",
     "    for index, lines in enumerate(csvFile):\n",
     "        if index == 0:\n",
     "            cleaned_info.append([\n",
+    "                \"sleep_id\",\n",
     "                \"sleep_begin\",\n",
     "                \"stage_start\",\n",
     "                \"time_since_begin_sec\",\n",
@@ -172,6 +170,7 @@
     "            continue\n",
     "        date_seen.add(start_time)\n",
     "        if not in_bed_time or in_bed_time > start_time:\n",
+    "            current_sleep_id += 1\n",
     "            in_bed_time = start_time\n",
     "        # for duration, stage in enumerate(\n",
     "        # for offset, (duration, stage) in enumerate(\n",
@@ -187,6 +186,7 @@
     "            # print(f\"{(index, duration) = } {stage = }\")\n",
     "            current_time = start_time + datetime.timedelta(seconds=offset*previous_duration)\n",
     "            cleaned_info.append([\n",
+    "                current_sleep_id,\n",
     "                in_bed_time,\n",
     "                current_time, \n",
     "                (current_time - in_bed_time).seconds,\n",
@@ -207,13 +207,22 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 157,
+   "execution_count": 202,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Finished Writing Cleaned Data\n"
+     ]
+    }
+   ],
    "source": [
     "with open(CLEANED_SLEEP_DATA_PATH, 'w') as clean_file:\n",
     "    write = csv.writer(clean_file)\n",
-    "    write.writerows(cleaned_info)"
+    "    write.writerows(cleaned_info)\n",
+    "print(\"Finished Writing Cleaned Data\")"
    ]
   },
   {
@@ -225,7 +234,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 158,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -235,7 +244,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 160,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -244,12 +253,36 @@
     "sleep_df_raw[\"sleep_begin\"] = pd.to_datetime(sleep_df_raw[\"sleep_begin\"], utc=True)\n",
     "sleep_df_raw[\"stage_start\"] = pd.to_datetime(sleep_df_raw[\"stage_start\"], utc=True)\n",
     "sleep_df_raw[\"stage_end\"] = pd.to_datetime(sleep_df_raw[\"stage_end\"], utc=True)\n",
-    "#   MAYBE 2. smaller units: int16 or int8 "
+    "#   2. Separate time, hour and minute\n",
+    "#   MAYBE 3. smaller units: int16 or int8 "
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 161,
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_minute(row, index):\n",
+    "    return row[index].time().minute\n",
+    "\n",
+    "def get_hour(row, index):\n",
+    "    return row[index].time().hour"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sleep_df_raw[\"stage_start_hour\"] = sleep_df_raw.apply (lambda row: get_hour(row, \"stage_start\"), axis=1)\n",
+    "sleep_df_raw[\"stage_start_minute\"] = sleep_df_raw.apply (lambda row: get_minute(row, \"stage_start\"), axis=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
    "metadata": {},
    "outputs": [
     {
@@ -258,21 +291,24 @@
      "text": [
       "<class 'pandas.core.frame.DataFrame'>\n",
       "RangeIndex: 551042 entries, 0 to 551041\n",
-      "Data columns (total 10 columns):\n",
+      "Data columns (total 13 columns):\n",
       " #   Column                Non-Null Count   Dtype              \n",
       "---  ------                --------------   -----              \n",
-      " 0   sleep_begin           551042 non-null  datetime64[ns, UTC]\n",
-      " 1   stage_start           551042 non-null  datetime64[ns, UTC]\n",
-      " 2   time_since_begin_sec  551042 non-null  int64              \n",
-      " 3   stage_duration_sec    551042 non-null  int64              \n",
-      " 4   stage_end             551042 non-null  datetime64[ns, UTC]\n",
-      " 5   stage_value           551042 non-null  int64              \n",
-      " 6   awake_probability     551042 non-null  float64            \n",
-      " 7   light_probability     551042 non-null  float64            \n",
-      " 8   deep_probability      551042 non-null  float64            \n",
-      " 9   rem_probability       551042 non-null  float64            \n",
-      "dtypes: datetime64[ns, UTC](3), float64(4), int64(3)\n",
-      "memory usage: 42.0 MB\n"
+      " 0   sleep_id              551042 non-null  int64              \n",
+      " 1   sleep_begin           551042 non-null  datetime64[ns, UTC]\n",
+      " 2   stage_start           551042 non-null  datetime64[ns, UTC]\n",
+      " 3   time_since_begin_sec  551042 non-null  int64              \n",
+      " 4   stage_duration_sec    551042 non-null  int64              \n",
+      " 5   stage_end             551042 non-null  datetime64[ns, UTC]\n",
+      " 6   stage_value           551042 non-null  int64              \n",
+      " 7   awake_probability     551042 non-null  float64            \n",
+      " 8   light_probability     551042 non-null  float64            \n",
+      " 9   deep_probability      551042 non-null  float64            \n",
+      " 10  rem_probability       551042 non-null  float64            \n",
+      " 11  stage_start_hour      551042 non-null  int64              \n",
+      " 12  stage_start_minute    551042 non-null  int64              \n",
+      "dtypes: datetime64[ns, UTC](3), float64(4), int64(6)\n",
+      "memory usage: 54.7 MB\n"
      ]
     }
    ],
@@ -282,7 +318,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 162,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [
     {
@@ -306,6 +342,7 @@
        "  <thead>\n",
        "    <tr style=\"text-align: right;\">\n",
        "      <th></th>\n",
+       "      <th>sleep_id</th>\n",
        "      <th>sleep_begin</th>\n",
        "      <th>stage_start</th>\n",
        "      <th>time_since_begin_sec</th>\n",
@@ -316,11 +353,14 @@
        "      <th>light_probability</th>\n",
        "      <th>deep_probability</th>\n",
        "      <th>rem_probability</th>\n",
+       "      <th>stage_start_hour</th>\n",
+       "      <th>stage_start_minute</th>\n",
        "    </tr>\n",
        "  </thead>\n",
        "  <tbody>\n",
        "    <tr>\n",
        "      <th>0</th>\n",
+       "      <td>0</td>\n",
        "      <td>2022-04-21 08:18:00+00:00</td>\n",
        "      <td>2022-04-21 08:18:00+00:00</td>\n",
        "      <td>0</td>\n",
@@ -331,9 +371,12 @@
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
+       "      <td>8</td>\n",
+       "      <td>18</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>1</th>\n",
+       "      <td>0</td>\n",
        "      <td>2022-04-21 08:18:00+00:00</td>\n",
        "      <td>2022-04-21 08:19:00+00:00</td>\n",
        "      <td>60</td>\n",
@@ -344,9 +387,12 @@
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
+       "      <td>8</td>\n",
+       "      <td>19</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>2</th>\n",
+       "      <td>0</td>\n",
        "      <td>2022-04-21 08:18:00+00:00</td>\n",
        "      <td>2022-04-21 08:20:00+00:00</td>\n",
        "      <td>120</td>\n",
@@ -357,9 +403,12 @@
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
+       "      <td>8</td>\n",
+       "      <td>20</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>3</th>\n",
+       "      <td>0</td>\n",
        "      <td>2022-04-21 08:18:00+00:00</td>\n",
        "      <td>2022-04-21 08:21:00+00:00</td>\n",
        "      <td>180</td>\n",
@@ -370,9 +419,12 @@
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
+       "      <td>8</td>\n",
+       "      <td>21</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>4</th>\n",
+       "      <td>0</td>\n",
        "      <td>2022-04-21 08:18:00+00:00</td>\n",
        "      <td>2022-04-21 08:22:00+00:00</td>\n",
        "      <td>240</td>\n",
@@ -383,6 +435,8 @@
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
+       "      <td>8</td>\n",
+       "      <td>22</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>...</th>\n",
@@ -396,9 +450,13 @@
        "      <td>...</td>\n",
        "      <td>...</td>\n",
        "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>551037</th>\n",
+       "      <td>1132</td>\n",
        "      <td>2019-02-11 06:11:00+00:00</td>\n",
        "      <td>2019-02-11 13:17:00+00:00</td>\n",
        "      <td>25560</td>\n",
@@ -409,9 +467,12 @@
        "      <td>1.0</td>\n",
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
+       "      <td>13</td>\n",
+       "      <td>17</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>551038</th>\n",
+       "      <td>1132</td>\n",
        "      <td>2019-02-11 06:11:00+00:00</td>\n",
        "      <td>2019-02-11 13:18:00+00:00</td>\n",
        "      <td>25620</td>\n",
@@ -422,9 +483,12 @@
        "      <td>1.0</td>\n",
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
+       "      <td>13</td>\n",
+       "      <td>18</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>551039</th>\n",
+       "      <td>1132</td>\n",
        "      <td>2019-02-11 06:11:00+00:00</td>\n",
        "      <td>2019-02-11 13:19:00+00:00</td>\n",
        "      <td>25680</td>\n",
@@ -435,9 +499,12 @@
        "      <td>1.0</td>\n",
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
+       "      <td>13</td>\n",
+       "      <td>19</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>551040</th>\n",
+       "      <td>1132</td>\n",
        "      <td>2019-02-11 06:11:00+00:00</td>\n",
        "      <td>2019-02-11 13:20:00+00:00</td>\n",
        "      <td>25740</td>\n",
@@ -448,9 +515,12 @@
        "      <td>1.0</td>\n",
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
+       "      <td>13</td>\n",
+       "      <td>20</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>551041</th>\n",
+       "      <td>1132</td>\n",
        "      <td>2019-02-11 06:11:00+00:00</td>\n",
        "      <td>2019-02-11 13:21:00+00:00</td>\n",
        "      <td>25800</td>\n",
@@ -461,25 +531,27 @@
        "      <td>1.0</td>\n",
        "      <td>0.0</td>\n",
        "      <td>0.0</td>\n",
+       "      <td>13</td>\n",
+       "      <td>21</td>\n",
        "    </tr>\n",
        "  </tbody>\n",
        "</table>\n",
-       "<p>551042 rows × 10 columns</p>\n",
+       "<p>551042 rows × 13 columns</p>\n",
        "</div>"
       ],
       "text/plain": [
-       "                     sleep_begin               stage_start  \\\n",
-       "0      2022-04-21 08:18:00+00:00 2022-04-21 08:18:00+00:00   \n",
-       "1      2022-04-21 08:18:00+00:00 2022-04-21 08:19:00+00:00   \n",
-       "2      2022-04-21 08:18:00+00:00 2022-04-21 08:20:00+00:00   \n",
-       "3      2022-04-21 08:18:00+00:00 2022-04-21 08:21:00+00:00   \n",
-       "4      2022-04-21 08:18:00+00:00 2022-04-21 08:22:00+00:00   \n",
-       "...                          ...                       ...   \n",
-       "551037 2019-02-11 06:11:00+00:00 2019-02-11 13:17:00+00:00   \n",
-       "551038 2019-02-11 06:11:00+00:00 2019-02-11 13:18:00+00:00   \n",
-       "551039 2019-02-11 06:11:00+00:00 2019-02-11 13:19:00+00:00   \n",
-       "551040 2019-02-11 06:11:00+00:00 2019-02-11 13:20:00+00:00   \n",
-       "551041 2019-02-11 06:11:00+00:00 2019-02-11 13:21:00+00:00   \n",
+       "        sleep_id               sleep_begin               stage_start  \\\n",
+       "0              0 2022-04-21 08:18:00+00:00 2022-04-21 08:18:00+00:00   \n",
+       "1              0 2022-04-21 08:18:00+00:00 2022-04-21 08:19:00+00:00   \n",
+       "2              0 2022-04-21 08:18:00+00:00 2022-04-21 08:20:00+00:00   \n",
+       "3              0 2022-04-21 08:18:00+00:00 2022-04-21 08:21:00+00:00   \n",
+       "4              0 2022-04-21 08:18:00+00:00 2022-04-21 08:22:00+00:00   \n",
+       "...          ...                       ...                       ...   \n",
+       "551037      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:17:00+00:00   \n",
+       "551038      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:18:00+00:00   \n",
+       "551039      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:19:00+00:00   \n",
+       "551040      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:20:00+00:00   \n",
+       "551041      1132 2019-02-11 06:11:00+00:00 2019-02-11 13:21:00+00:00   \n",
        "\n",
        "        time_since_begin_sec  stage_duration_sec                 stage_end  \\\n",
        "0                          0                  60 2022-04-21 08:19:00+00:00   \n",
@@ -507,23 +579,23 @@
        "551040            1                0.0                1.0               0.0   \n",
        "551041            1                0.0                1.0               0.0   \n",
        "\n",
-       "        rem_probability  \n",
-       "0                   0.0  \n",
-       "1                   0.0  \n",
-       "2                   0.0  \n",
-       "3                   0.0  \n",
-       "4                   0.0  \n",
-       "...                 ...  \n",
-       "551037              0.0  \n",
-       "551038              0.0  \n",
-       "551039              0.0  \n",
-       "551040              0.0  \n",
-       "551041              0.0  \n",
+       "        rem_probability  stage_start_hour  stage_start_minute  \n",
+       "0                   0.0                 8                  18  \n",
+       "1                   0.0                 8                  19  \n",
+       "2                   0.0                 8                  20  \n",
+       "3                   0.0                 8                  21  \n",
+       "4                   0.0                 8                  22  \n",
+       "...                 ...               ...                 ...  \n",
+       "551037              0.0                13                  17  \n",
+       "551038              0.0                13                  18  \n",
+       "551039              0.0                13                  19  \n",
+       "551040              0.0                13                  20  \n",
+       "551041              0.0                13                  21  \n",
        "\n",
-       "[551042 rows x 10 columns]"
+       "[551042 rows x 13 columns]"
       ]
      },
-     "execution_count": 162,
+     "execution_count": 11,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -534,35 +606,451 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 143,
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sleep_data = sleep_df_raw[[\"sleep_id\", \"stage_start_hour\", \"stage_start_minute\", \"awake_probability\", \"rem_probability\",\"light_probability\", \"deep_probability\"]]\n",
+    "sleep_data.insert(loc=1, column=\"minutes_since_begin\" , value= sleep_df_raw[\"time_since_begin_sec\"]//60)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
    "metadata": {},
    "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "   sleep_id  minutes_since_begin  stage_start_hour  stage_start_minute  \\\n",
+      "0         0                    0                 8                  18   \n",
+      "1         0                    1                 8                  19   \n",
+      "2         0                    2                 8                  20   \n",
+      "3         0                    3                 8                  21   \n",
+      "4         0                    4                 8                  22   \n",
+      "\n",
+      "   awake_probability  rem_probability  light_probability  deep_probability  \n",
+      "0                1.0              0.0                0.0               0.0  \n",
+      "1                1.0              0.0                0.0               0.0  \n",
+      "2                1.0              0.0                0.0               0.0  \n",
+      "3                1.0              0.0                0.0               0.0  \n",
+      "4                1.0              0.0                0.0               0.0  \n",
+      "<class 'pandas.core.frame.DataFrame'>\n",
+      "RangeIndex: 551042 entries, 0 to 551041\n",
+      "Data columns (total 8 columns):\n",
+      " #   Column               Non-Null Count   Dtype  \n",
+      "---  ------               --------------   -----  \n",
+      " 0   sleep_id             551042 non-null  int64  \n",
+      " 1   minutes_since_begin  551042 non-null  int64  \n",
+      " 2   stage_start_hour     551042 non-null  int64  \n",
+      " 3   stage_start_minute   551042 non-null  int64  \n",
+      " 4   awake_probability    551042 non-null  float64\n",
+      " 5   rem_probability      551042 non-null  float64\n",
+      " 6   light_probability    551042 non-null  float64\n",
+      " 7   deep_probability     551042 non-null  float64\n",
+      "dtypes: float64(4), int64(4)\n",
+      "memory usage: 33.6 MB\n",
+      "None\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(sleep_data.head())\n",
+    "print(sleep_data.info())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Model Development"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Helper functions and class"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "TEST_SIZE = 122\n",
+    "VALIDATION_SIZE = 183\n",
+    "TIME_STEP_INPUT = 10 # in minutes\n",
+    "BATCH_SIZE = 32"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def training_test_split_by_unique_index(data, index: str, test_size: int = 10):\n",
+    "    test_ids = np.random.choice(data[index].unique(), size = test_size, replace=False)\n",
+    "    return data[~data[index].isin(test_ids)], data[data[index].isin(test_ids)]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Adapted from https://www.tensorflow.org/tutorials/structured_data/time_series\n",
+    "class WindowGenerator():\n",
+    "    def __init__(self, data, index: str = \"sleep_id\", input_width: int = TIME_STEP_INPUT, validation_size: int = VALIDATION_SIZE, test_size: int = TEST_SIZE, generate_data_now: bool = True):\n",
+    "        # Partition data\n",
+    "        self.training, self.testing = training_test_split_by_unique_index(sleep_data, index, test_size)\n",
+    "        self.training, self.validation = training_test_split_by_unique_index(self.training, index, validation_size)\n",
+    "\n",
+    "        # Window paramters\n",
+    "        self.input_width = input_width\n",
+    "        self.label_width = 1\n",
+    "        self.shift = 1\n",
+    "\n",
+    "        self.total_window_size = self.input_width + self.shift\n",
+    "\n",
+    "        self.input_slice = slice(0, input_width)\n",
+    "        self.input_indices = np.arange(self.total_window_size)[self.input_slice]\n",
+    "\n",
+    "        self.label_start = self.total_window_size - self.label_width\n",
+    "        self.labels_slice = slice(self.label_start, None)\n",
+    "        self.label_indices = np.arange(self.total_window_size)[self.labels_slice]\n",
+    "        \n",
+    "        if generate_data_now:\n",
+    "            self.training_ds = self.make_dataset(self.training)\n",
+    "            self.validation_ds = self.make_dataset(self.validation)\n",
+    "            self.testing_ds = self.make_dataset(self.testing)\n",
+    "\n",
+    "\n",
+    "    def __repr__(self):\n",
+    "        return \"WindowGenerator:\\n\\t\" +'\\n\\t'.join([\n",
+    "            f'Total window size: {self.total_window_size}',\n",
+    "            f'Input indices: {self.input_indices}',\n",
+    "            f'Label indices: {self.label_indices}',\n",
+    "        ])\n",
+    "\n",
+    "    def split_window(self, features):\n",
+    "        inputs = features[:, self.input_slice, :]\n",
+    "        labels = features[:, self.labels_slice, :]\n",
+    "        inputs.set_shape([None, self.input_width, None])\n",
+    "        labels.set_shape([None, self.label_width, None])\n",
+    "\n",
+    "        return inputs, labels\n",
+    "\n",
+    "    def make_dataset(self, data, index_group: str = \"sleep_id\", sort_by: str = \"minutes_since_begin\"):\n",
+    "        ds_all = None\n",
+    "        for i_group in data[index_group].unique():\n",
+    "            subset_data = np.array(data[data[index_group] == i_group].sort_values(by=[sort_by]), dtype=np.float32)\n",
+    "            ds = tf.keras.utils.timeseries_dataset_from_array(\n",
+    "                data=subset_data,\n",
+    "                targets=None,\n",
+    "                sequence_length=self.total_window_size,\n",
+    "                sequence_stride=1,\n",
+    "                shuffle=False,\n",
+    "                batch_size=BATCH_SIZE,)\n",
+    "            ds_all = ds if ds_all is None else ds_all.concatenate(ds)\n",
+    "        ds_all = ds_all.map(self.split_window)\n",
+    "\n",
+    "        return ds_all\n",
+    "\n",
+    "    # def generate_all_datasets(self):\n",
+    "    #     self._training_ds = self.make_dataset(self.training)\n",
+    "    #     self._validation_ds = self.make_dataset(self.validation)\n",
+    "    #     self._testing_ds = self.make_dataset(self.testing)\n",
+    "\n",
+    "    # def training_dataset(self):\n",
+    "    #     if self._training_ds is None:\n",
+    "    #         self._training_ds = self.make_dataset(self.training)\n",
+    "    #     return self._training_ds\n",
+    "\n",
+    "    # def validation_dataset(self):\n",
+    "    #     if self._validation_ds is None:\n",
+    "    #         self._validation_ds = self.make_dataset(self.validation)\n",
+    "    #     return self._validation_ds\n",
+    "\n",
+    "    # def test_dataset(self):\n",
+    "    #     if self._testing_ds is None:\n",
+    "    #         self._testing_ds = self.make_dataset(self.testing)\n",
+    "    #     return self._testing_ds"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2022-05-11 00:38:01.423873: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
+      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
+     ]
+    },
     {
      "data": {
       "text/plain": [
-       "start                2019-02-11 06:11:00+00:00\n",
-       "duration                                    60\n",
-       "end                  2019-02-11 06:12:00+00:00\n",
-       "stage                                        0\n",
-       "awake_probability                          0.0\n",
-       "light_probability                          0.0\n",
-       "deep_probability                           0.0\n",
-       "rem_probability                            0.0\n",
-       "dtype: object"
+       "WindowGenerator:\n",
+       "\tTotal window size: 11\n",
+       "\tInput indices: [0 1 2 3 4 5 6 7 8 9]\n",
+       "\tLabel indices: [10]"
       ]
      },
-     "execution_count": 143,
+     "execution_count": 17,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
+   "source": [
+    "wg = WindowGenerator(sleep_data)\n",
+    "wg"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "1836"
+      ]
+     },
+     "execution_count": 20,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(wg.testing_ds)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Data Prep\n",
+    "\n",
+    "All inputs follow: (batch_size, timesteps, input_dim)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'training_test_split_by_unique_index' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[1;32m/Users/nowadmin/Documents/School Folder/CS 437/Lab/Final Project/tf_model.ipynb Cell 32'\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/tf_model.ipynb#ch0000068?line=0'>1</a>\u001b[0m training_sleep_data, test_sleep_data \u001b[39m=\u001b[39m training_test_split_by_unique_index(sleep_data, \u001b[39m\"\u001b[39m\u001b[39msleep_id\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m5\u001b[39m)\n",
+      "\u001b[0;31mNameError\u001b[0m: name 'training_test_split_by_unique_index' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "training_sleep_data, test_sleep_data = training_test_split_by_unique_index(sleep_data, \"sleep_id\", 5)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 288,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "set()"
+      ]
+     },
+     "execution_count": 288,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "set(training.sleep_id.unique()).intersection(set(test.sleep_id.unique()))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 239,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sleep_data_tensor = tf.convert_to_tensor(sleep_data)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 240,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "TensorShape([551042, 7])"
+      ]
+     },
+     "execution_count": 240,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "sleep_data_tensor.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 241,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<tf.Tensor: shape=(7,), dtype=float64, numpy=array([ 1.,  8., 19.,  1.,  0.,  0.,  0.])>"
+      ]
+     },
+     "execution_count": 241,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "sleep_data_tensor[1]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Model 1: LSTM"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 215,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hyper-parameters\n",
+    "LSTM_UNITS = 16\n",
+    "LSTM_LEARNING_RATE = 0.0001"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 278,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Model: \"sequential_12\"\n",
+      "_________________________________________________________________\n",
+      " Layer (type)                Output Shape              Param #   \n",
+      "=================================================================\n",
+      " lstm_14 (LSTM)              (None, 10, 16)            1600      \n",
+      "                                                                 \n",
+      " dense_9 (Dense)             (None, 10, 4)             68        \n",
+      "                                                                 \n",
+      "=================================================================\n",
+      "Total params: 1,668\n",
+      "Trainable params: 1,668\n",
+      "Non-trainable params: 0\n",
+      "_________________________________________________________________\n",
+      "None\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Model Definition\n",
+    "lstm_model = keras.Sequential()\n",
+    "lstm_model.add(layers.Input(shape=(TIME_STEP_INPUT, 8)))\n",
+    "lstm_model.add(layers.LSTM(LSTM_UNITS, stateful=False, return_sequences=True))\n",
+    "# lstm_model.add(layers.LSTM(LSTM_UNITS, stateful=False, return_sequences=True))\n",
+    "lstm_model.add(layers.Dense(SLEEP_STAGES))\n",
+    "lstm_model.build()\n",
+    "print(lstm_model.summary())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Model Training\n",
+    "lstm_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
+    "lstm_optm = keras.optimizers.Adam(learning_rate=LSTM_LEARNING_RATE)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Model 2: GRU"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Hyper-paramters\n",
+    "GRU_UNITS = 16"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gru_model = keras.Sequential([\n",
+    "    layers.GRU(GRU_UNITS),\n",
+    "    layers.Dense(SLEEP_STAGES)\n",
+    "])\n",
+    "gru_model.add(layers.Embedding(input_dim=1000, output_dim=64))\n",
+    "print(gru_model.summary())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": []
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Model 1: GRU"
+    "### Model 3: Attention Mechanism"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ATTENTION_UNITS = 16"
    ]
   },
   {
@@ -571,15 +1059,32 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model = keras.Sequential()\n",
-    "model.add(layers.Embedding(input_dim=1000, output_dim=64))\n",
-    "model.add(layers.GRU(128))\n",
-    "model.add(layers.Dense(10))"
+    "am_model = keras.Sequential([\n",
+    "    Attention(ATTENTION_UNITS)\n",
+    "    layers.Dense(SLEEP_STAGES)\n",
+    "])\n",
+    "print(am_model.summary())"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
+   "source": [
+    "### Model Head-to-Head testing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
    "source": []
   },
   {