diff --git a/.gitignore b/.gitignore
index 8fa6d85bae3877def6513c6d9fe4a837c087aa96..2269427c1edc001edae94bdf7e9fce0e53fad2d8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 *.pdf
-.data/
\ No newline at end of file
+.data/
+*.pyc
\ No newline at end of file
diff --git a/tf_model.ipynb b/tf_model.ipynb
index efcdfa79fba233262de74809a897831726664598..4404a8cf282f6d3e69e8e2a9782861a11f28b61f 100644
--- a/tf_model.ipynb
+++ b/tf_model.ipynb
@@ -7,9 +7,16 @@
     "# Notebook for training and testing AI models"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Setup"
+   ]
+  },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -23,7 +30,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
     {
@@ -40,7 +47,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -51,7 +58,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -234,7 +241,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -244,7 +251,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -259,7 +266,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -272,7 +279,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -282,7 +289,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [
     {
@@ -318,7 +325,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [
     {
@@ -595,7 +602,7 @@
        "[551042 rows x 13 columns]"
       ]
      },
-     "execution_count": 11,
+     "execution_count": 12,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -606,7 +613,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -616,7 +623,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [
     {
@@ -660,6 +667,221 @@
     "print(sleep_data.info())"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>sleep_id</th>\n",
+       "      <th>minutes_since_begin</th>\n",
+       "      <th>stage_start_hour</th>\n",
+       "      <th>stage_start_minute</th>\n",
+       "      <th>awake_probability</th>\n",
+       "      <th>rem_probability</th>\n",
+       "      <th>light_probability</th>\n",
+       "      <th>deep_probability</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>0</td>\n",
+       "      <td>0</td>\n",
+       "      <td>8</td>\n",
+       "      <td>18</td>\n",
+       "      <td>1.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>0</td>\n",
+       "      <td>1</td>\n",
+       "      <td>8</td>\n",
+       "      <td>19</td>\n",
+       "      <td>1.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>0</td>\n",
+       "      <td>2</td>\n",
+       "      <td>8</td>\n",
+       "      <td>20</td>\n",
+       "      <td>1.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>0</td>\n",
+       "      <td>3</td>\n",
+       "      <td>8</td>\n",
+       "      <td>21</td>\n",
+       "      <td>1.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>0</td>\n",
+       "      <td>4</td>\n",
+       "      <td>8</td>\n",
+       "      <td>22</td>\n",
+       "      <td>1.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>...</th>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "      <td>...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>551037</th>\n",
+       "      <td>1132</td>\n",
+       "      <td>426</td>\n",
+       "      <td>13</td>\n",
+       "      <td>17</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>1.0</td>\n",
+       "      <td>0.0</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>551038</th>\n",
+       "      <td>1132</td>\n",
+       "      <td>427</td>\n",
+       "      <td>13</td>\n",
+       "      <td>18</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>1.0</td>\n",
+       "      <td>0.0</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>551039</th>\n",
+       "      <td>1132</td>\n",
+       "      <td>428</td>\n",
+       "      <td>13</td>\n",
+       "      <td>19</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>1.0</td>\n",
+       "      <td>0.0</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>551040</th>\n",
+       "      <td>1132</td>\n",
+       "      <td>429</td>\n",
+       "      <td>13</td>\n",
+       "      <td>20</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>1.0</td>\n",
+       "      <td>0.0</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>551041</th>\n",
+       "      <td>1132</td>\n",
+       "      <td>430</td>\n",
+       "      <td>13</td>\n",
+       "      <td>21</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>0.0</td>\n",
+       "      <td>1.0</td>\n",
+       "      <td>0.0</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "<p>551042 rows × 8 columns</p>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "        sleep_id  minutes_since_begin  stage_start_hour  stage_start_minute  \\\n",
+       "0              0                    0                 8                  18   \n",
+       "1              0                    1                 8                  19   \n",
+       "2              0                    2                 8                  20   \n",
+       "3              0                    3                 8                  21   \n",
+       "4              0                    4                 8                  22   \n",
+       "...          ...                  ...               ...                 ...   \n",
+       "551037      1132                  426                13                  17   \n",
+       "551038      1132                  427                13                  18   \n",
+       "551039      1132                  428                13                  19   \n",
+       "551040      1132                  429                13                  20   \n",
+       "551041      1132                  430                13                  21   \n",
+       "\n",
+       "        awake_probability  rem_probability  light_probability  \\\n",
+       "0                     1.0              0.0                0.0   \n",
+       "1                     1.0              0.0                0.0   \n",
+       "2                     1.0              0.0                0.0   \n",
+       "3                     1.0              0.0                0.0   \n",
+       "4                     1.0              0.0                0.0   \n",
+       "...                   ...              ...                ...   \n",
+       "551037                0.0              0.0                1.0   \n",
+       "551038                0.0              0.0                1.0   \n",
+       "551039                0.0              0.0                1.0   \n",
+       "551040                0.0              0.0                1.0   \n",
+       "551041                0.0              0.0                1.0   \n",
+       "\n",
+       "        deep_probability  \n",
+       "0                    0.0  \n",
+       "1                    0.0  \n",
+       "2                    0.0  \n",
+       "3                    0.0  \n",
+       "4                    0.0  \n",
+       "...                  ...  \n",
+       "551037               0.0  \n",
+       "551038               0.0  \n",
+       "551039               0.0  \n",
+       "551040               0.0  \n",
+       "551041               0.0  \n",
+       "\n",
+       "[551042 rows x 8 columns]"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "sleep_data"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -676,7 +898,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -688,7 +910,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 17,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -699,13 +921,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 31,
    "metadata": {},
    "outputs": [],
    "source": [
     "# Adapted from https://www.tensorflow.org/tutorials/structured_data/time_series\n",
     "class WindowGenerator():\n",
-    "    def __init__(self, data, index: str = \"sleep_id\", input_width: int = TIME_STEP_INPUT, validation_size: int = VALIDATION_SIZE, test_size: int = TEST_SIZE, generate_data_now: bool = True):\n",
+    "    def __init__(self, data, index: str = \"sleep_id\", input_width: int = TIME_STEP_INPUT, validation_size: int = VALIDATION_SIZE, test_size: int = TEST_SIZE, input_feature_slice: slice = slice(1,100), label_feature_slice: slice = slice(-4,100), generate_data_now: bool = True):\n",
     "        # Partition data\n",
     "        self.training, self.testing = training_test_split_by_unique_index(sleep_data, index, test_size)\n",
     "        self.training, self.validation = training_test_split_by_unique_index(self.training, index, validation_size)\n",
@@ -723,7 +945,12 @@
     "        self.label_start = self.total_window_size - self.label_width\n",
     "        self.labels_slice = slice(self.label_start, None)\n",
     "        self.label_indices = np.arange(self.total_window_size)[self.labels_slice]\n",
-    "        \n",
+    "\n",
+    "        self.input_feature_slice = input_feature_slice\n",
+    "        self.label_feature_slice = label_feature_slice\n",
+    "\n",
+    "        self.sample_ds = self.make_dataset(sleep_data[sleep_data[index] == 0])\n",
+    "\n",
     "        if generate_data_now:\n",
     "            self.training_ds = self.make_dataset(self.training)\n",
     "            self.validation_ds = self.make_dataset(self.validation)\n",
@@ -738,11 +965,10 @@
     "        ])\n",
     "\n",
     "    def split_window(self, features):\n",
-    "        inputs = features[:, self.input_slice, :]\n",
-    "        labels = features[:, self.labels_slice, :]\n",
+    "        inputs = features[:, self.input_slice, self.input_feature_slice]\n",
+    "        labels = features[:, self.labels_slice, self.label_feature_slice]\n",
     "        inputs.set_shape([None, self.input_width, None])\n",
     "        labels.set_shape([None, self.label_width, None])\n",
-    "\n",
     "        return inputs, labels\n",
     "\n",
     "    def make_dataset(self, data, index_group: str = \"sleep_id\", sort_by: str = \"minutes_since_begin\"):\n",
@@ -782,58 +1008,6 @@
     "    #     return self._testing_ds"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": 17,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "2022-05-11 00:38:01.423873: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA\n",
-      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
-     ]
-    },
-    {
-     "data": {
-      "text/plain": [
-       "WindowGenerator:\n",
-       "\tTotal window size: 11\n",
-       "\tInput indices: [0 1 2 3 4 5 6 7 8 9]\n",
-       "\tLabel indices: [10]"
-      ]
-     },
-     "execution_count": 17,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "wg = WindowGenerator(sleep_data)\n",
-    "wg"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 20,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "1836"
-      ]
-     },
-     "execution_count": 20,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "len(wg.testing_ds)"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -845,92 +1019,75 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [
-    {
-     "ename": "NameError",
-     "evalue": "name 'training_test_split_by_unique_index' is not defined",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
-      "\u001b[1;32m/Users/nowadmin/Documents/School Folder/CS 437/Lab/Final Project/tf_model.ipynb Cell 32'\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/tf_model.ipynb#ch0000068?line=0'>1</a>\u001b[0m training_sleep_data, test_sleep_data \u001b[39m=\u001b[39m training_test_split_by_unique_index(sleep_data, \u001b[39m\"\u001b[39m\u001b[39msleep_id\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m5\u001b[39m)\n",
-      "\u001b[0;31mNameError\u001b[0m: name 'training_test_split_by_unique_index' is not defined"
-     ]
-    }
-   ],
-   "source": [
-    "training_sleep_data, test_sleep_data = training_test_split_by_unique_index(sleep_data, \"sleep_id\", 5)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 288,
+   "execution_count": 32,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "set()"
+       "WindowGenerator:\n",
+       "\tTotal window size: 11\n",
+       "\tInput indices: [0 1 2 3 4 5 6 7 8 9]\n",
+       "\tLabel indices: [10]"
       ]
      },
-     "execution_count": 288,
+     "execution_count": 32,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "set(training.sleep_id.unique()).intersection(set(test.sleep_id.unique()))"
+    "wg = WindowGenerator(sleep_data)\n",
+    "wg"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 239,
+   "execution_count": 33,
    "metadata": {},
    "outputs": [],
    "source": [
-    "sleep_data_tensor = tf.convert_to_tensor(sleep_data)"
+    "sample = wg.sample_ds.take(1)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 240,
+   "execution_count": 34,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "TensorShape([551042, 7])"
-      ]
-     },
-     "execution_count": 240,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
    "source": [
-    "sleep_data_tensor.shape"
+    "sample_array = list(sample.as_numpy_iterator())"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 241,
+   "execution_count": 37,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "<tf.Tensor: shape=(7,), dtype=float64, numpy=array([ 1.,  8., 19.,  1.,  0.,  0.,  0.])>"
+       "(array([[18.,  8., 36.,  1.,  0.,  0.,  0.],\n",
+       "        [19.,  8., 37.,  1.,  0.,  0.,  0.],\n",
+       "        [20.,  8., 38.,  1.,  0.,  0.,  0.],\n",
+       "        [21.,  8., 39.,  1.,  0.,  0.,  0.],\n",
+       "        [22.,  8., 40.,  1.,  0.,  0.,  0.],\n",
+       "        [23.,  8., 41.,  1.,  0.,  0.,  0.],\n",
+       "        [24.,  8., 42.,  1.,  0.,  0.,  0.],\n",
+       "        [25.,  8., 43.,  1.,  0.,  0.,  0.],\n",
+       "        [26.,  8., 44.,  1.,  0.,  0.,  0.],\n",
+       "        [27.,  8., 45.,  1.,  0.,  0.,  0.]], dtype=float32),\n",
+       " array([[0., 0., 1., 0.]], dtype=float32))"
       ]
      },
-     "execution_count": 241,
+     "execution_count": 37,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "sleep_data_tensor[1]"
+    "INDEX_TIMESTEP = 18\n",
+    "sample_array[0][0][INDEX_TIMESTEP], sample_array[0][1][INDEX_TIMESTEP]"
    ]
   },
   {