From acb172711dcaf2736b81be19f4ab3755a6c15d2b Mon Sep 17 00:00:00 2001
From: SurajSSingh <surajss@uci.edu>
Date: Wed, 11 May 2022 12:00:23 -0700
Subject: [PATCH] Added Export Simple Data

---
 tf_model.ipynb | 339 ++++++++++++++++++++++++++++++++-----------------
 1 file changed, 222 insertions(+), 117 deletions(-)

diff --git a/tf_model.ipynb b/tf_model.ipynb
index 4404a8c..0043c5f 100644
--- a/tf_model.ipynb
+++ b/tf_model.ipynb
@@ -16,7 +16,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -30,7 +30,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
@@ -47,7 +47,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -58,7 +58,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -241,7 +241,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -251,7 +251,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -266,7 +266,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -279,7 +279,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -289,7 +289,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
     {
@@ -325,7 +325,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [
     {
@@ -602,7 +602,7 @@
        "[551042 rows x 13 columns]"
       ]
      },
-     "execution_count": 12,
+     "execution_count": 10,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -613,7 +613,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 37,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -623,7 +623,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 32,
    "metadata": {},
    "outputs": [
     {
@@ -669,7 +669,54 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 38,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sleep_data.to_csv(\".data/sleep_data_simple.csv\", index=False, index_label=False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Model Development"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "TEST_SIZE = 122\n",
+    "VALIDATION_SIZE = 183\n",
+    "\n",
+    "BATCH_SIZE = 32\n",
+    "INPUT_TIME_STEP = 10 # in minutes\n",
+    "INPUT_FEATURES_SIZE = 7\n",
+    "MAX_EPOCHS = 20"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sleep_data = pd.read_csv(\".data/sleep_data_simple.csv\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
    "metadata": {},
    "outputs": [
     {
@@ -693,6 +740,7 @@
        "  <thead>\n",
        "    <tr style=\"text-align: right;\">\n",
        "      <th></th>\n",
+       "      <th>Unnamed: 0</th>\n",
        "      <th>sleep_id</th>\n",
        "      <th>minutes_since_begin</th>\n",
        "      <th>stage_start_hour</th>\n",
@@ -708,6 +756,7 @@
        "      <th>0</th>\n",
        "      <td>0</td>\n",
        "      <td>0</td>\n",
+       "      <td>0</td>\n",
        "      <td>8</td>\n",
        "      <td>18</td>\n",
        "      <td>1.0</td>\n",
@@ -717,6 +766,7 @@
        "    </tr>\n",
        "    <tr>\n",
        "      <th>1</th>\n",
+       "      <td>1</td>\n",
        "      <td>0</td>\n",
        "      <td>1</td>\n",
        "      <td>8</td>\n",
@@ -728,6 +778,7 @@
        "    </tr>\n",
        "    <tr>\n",
        "      <th>2</th>\n",
+       "      <td>2</td>\n",
        "      <td>0</td>\n",
        "      <td>2</td>\n",
        "      <td>8</td>\n",
@@ -739,6 +790,7 @@
        "    </tr>\n",
        "    <tr>\n",
        "      <th>3</th>\n",
+       "      <td>3</td>\n",
        "      <td>0</td>\n",
        "      <td>3</td>\n",
        "      <td>8</td>\n",
@@ -750,6 +802,7 @@
        "    </tr>\n",
        "    <tr>\n",
        "      <th>4</th>\n",
+       "      <td>4</td>\n",
        "      <td>0</td>\n",
        "      <td>4</td>\n",
        "      <td>8</td>\n",
@@ -769,9 +822,11 @@
        "      <td>...</td>\n",
        "      <td>...</td>\n",
        "      <td>...</td>\n",
+       "      <td>...</td>\n",
        "    </tr>\n",
        "    <tr>\n",
        "      <th>551037</th>\n",
+       "      <td>551037</td>\n",
        "      <td>1132</td>\n",
        "      <td>426</td>\n",
        "      <td>13</td>\n",
@@ -783,6 +838,7 @@
        "    </tr>\n",
        "    <tr>\n",
        "      <th>551038</th>\n",
+       "      <td>551038</td>\n",
        "      <td>1132</td>\n",
        "      <td>427</td>\n",
        "      <td>13</td>\n",
@@ -794,6 +850,7 @@
        "    </tr>\n",
        "    <tr>\n",
        "      <th>551039</th>\n",
+       "      <td>551039</td>\n",
        "      <td>1132</td>\n",
        "      <td>428</td>\n",
        "      <td>13</td>\n",
@@ -805,6 +862,7 @@
        "    </tr>\n",
        "    <tr>\n",
        "      <th>551040</th>\n",
+       "      <td>551040</td>\n",
        "      <td>1132</td>\n",
        "      <td>429</td>\n",
        "      <td>13</td>\n",
@@ -816,6 +874,7 @@
        "    </tr>\n",
        "    <tr>\n",
        "      <th>551041</th>\n",
+       "      <td>551041</td>\n",
        "      <td>1132</td>\n",
        "      <td>430</td>\n",
        "      <td>13</td>\n",
@@ -827,53 +886,53 @@
        "    </tr>\n",
        "  </tbody>\n",
        "</table>\n",
-       "<p>551042 rows × 8 columns</p>\n",
+       "<p>551042 rows × 9 columns</p>\n",
        "</div>"
       ],
       "text/plain": [
-       "        sleep_id  minutes_since_begin  stage_start_hour  stage_start_minute  \\\n",
-       "0              0                    0                 8                  18   \n",
-       "1              0                    1                 8                  19   \n",
-       "2              0                    2                 8                  20   \n",
-       "3              0                    3                 8                  21   \n",
-       "4              0                    4                 8                  22   \n",
-       "...          ...                  ...               ...                 ...   \n",
-       "551037      1132                  426                13                  17   \n",
-       "551038      1132                  427                13                  18   \n",
-       "551039      1132                  428                13                  19   \n",
-       "551040      1132                  429                13                  20   \n",
-       "551041      1132                  430                13                  21   \n",
+       "        Unnamed: 0  sleep_id  minutes_since_begin  stage_start_hour  \\\n",
+       "0                0         0                    0                 8   \n",
+       "1                1         0                    1                 8   \n",
+       "2                2         0                    2                 8   \n",
+       "3                3         0                    3                 8   \n",
+       "4                4         0                    4                 8   \n",
+       "...            ...       ...                  ...               ...   \n",
+       "551037      551037      1132                  426                13   \n",
+       "551038      551038      1132                  427                13   \n",
+       "551039      551039      1132                  428                13   \n",
+       "551040      551040      1132                  429                13   \n",
+       "551041      551041      1132                  430                13   \n",
        "\n",
-       "        awake_probability  rem_probability  light_probability  \\\n",
-       "0                     1.0              0.0                0.0   \n",
-       "1                     1.0              0.0                0.0   \n",
-       "2                     1.0              0.0                0.0   \n",
-       "3                     1.0              0.0                0.0   \n",
-       "4                     1.0              0.0                0.0   \n",
-       "...                   ...              ...                ...   \n",
-       "551037                0.0              0.0                1.0   \n",
-       "551038                0.0              0.0                1.0   \n",
-       "551039                0.0              0.0                1.0   \n",
-       "551040                0.0              0.0                1.0   \n",
-       "551041                0.0              0.0                1.0   \n",
+       "        stage_start_minute  awake_probability  rem_probability  \\\n",
+       "0                       18                1.0              0.0   \n",
+       "1                       19                1.0              0.0   \n",
+       "2                       20                1.0              0.0   \n",
+       "3                       21                1.0              0.0   \n",
+       "4                       22                1.0              0.0   \n",
+       "...                    ...                ...              ...   \n",
+       "551037                  17                0.0              0.0   \n",
+       "551038                  18                0.0              0.0   \n",
+       "551039                  19                0.0              0.0   \n",
+       "551040                  20                0.0              0.0   \n",
+       "551041                  21                0.0              0.0   \n",
        "\n",
-       "        deep_probability  \n",
-       "0                    0.0  \n",
-       "1                    0.0  \n",
-       "2                    0.0  \n",
-       "3                    0.0  \n",
-       "4                    0.0  \n",
-       "...                  ...  \n",
-       "551037               0.0  \n",
-       "551038               0.0  \n",
-       "551039               0.0  \n",
-       "551040               0.0  \n",
-       "551041               0.0  \n",
+       "        light_probability  deep_probability  \n",
+       "0                     0.0               0.0  \n",
+       "1                     0.0               0.0  \n",
+       "2                     0.0               0.0  \n",
+       "3                     0.0               0.0  \n",
+       "4                     0.0               0.0  \n",
+       "...                   ...               ...  \n",
+       "551037                1.0               0.0  \n",
+       "551038                1.0               0.0  \n",
+       "551039                1.0               0.0  \n",
+       "551040                1.0               0.0  \n",
+       "551041                1.0               0.0  \n",
        "\n",
-       "[551042 rows x 8 columns]"
+       "[551042 rows x 9 columns]"
       ]
      },
-     "execution_count": 15,
+     "execution_count": 35,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -882,13 +941,6 @@
     "sleep_data"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Model Development"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -896,18 +948,6 @@
     "### Helper functions and class"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": 16,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "TEST_SIZE = 122\n",
-    "VALIDATION_SIZE = 183\n",
-    "TIME_STEP_INPUT = 10 # in minutes\n",
-    "BATCH_SIZE = 32"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": 17,
@@ -921,13 +961,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 24,
    "metadata": {},
    "outputs": [],
    "source": [
     "# Adapted from https://www.tensorflow.org/tutorials/structured_data/time_series\n",
     "class WindowGenerator():\n",
-    "    def __init__(self, data, index: str = \"sleep_id\", input_width: int = TIME_STEP_INPUT, validation_size: int = VALIDATION_SIZE, test_size: int = TEST_SIZE, input_feature_slice: slice = slice(1,100), label_feature_slice: slice = slice(-4,100), generate_data_now: bool = True):\n",
+    "    def __init__(self, data, index: str = \"sleep_id\", input_width: int = INPUT_TIME_STEP, validation_size: int = VALIDATION_SIZE, test_size: int = TEST_SIZE, input_feature_slice: slice = slice(1,100), label_feature_slice: slice = slice(-4,100), generate_data_now: bool = True):\n",
     "        # Partition data\n",
     "        self.training, self.testing = training_test_split_by_unique_index(sleep_data, index, test_size)\n",
     "        self.training, self.validation = training_test_split_by_unique_index(self.training, index, validation_size)\n",
@@ -952,9 +992,9 @@
     "        self.sample_ds = self.make_dataset(sleep_data[sleep_data[index] == 0])\n",
     "\n",
     "        if generate_data_now:\n",
-    "            self.training_ds = self.make_dataset(self.training)\n",
-    "            self.validation_ds = self.make_dataset(self.validation)\n",
-    "            self.testing_ds = self.make_dataset(self.testing)\n",
+    "            self.training_ds = self.make_dataset(self.training, index)\n",
+    "            self.validation_ds = self.make_dataset(self.validation, index)\n",
+    "            self.testing_ds = self.make_dataset(self.testing, index)\n",
     "\n",
     "\n",
     "    def __repr__(self):\n",
@@ -966,9 +1006,9 @@
     "\n",
     "    def split_window(self, features):\n",
     "        inputs = features[:, self.input_slice, self.input_feature_slice]\n",
-    "        labels = features[:, self.labels_slice, self.label_feature_slice]\n",
+    "        labels = tf.squeeze(features[:, self.labels_slice, self.label_feature_slice])\n",
     "        inputs.set_shape([None, self.input_width, None])\n",
-    "        labels.set_shape([None, self.label_width, None])\n",
+    "        # labels.set_shape([None, self.label_width, None])\n",
     "        return inputs, labels\n",
     "\n",
     "    def make_dataset(self, data, index_group: str = \"sleep_id\", sort_by: str = \"minutes_since_begin\"):\n",
@@ -1019,21 +1059,31 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 32,
+   "execution_count": 25,
    "metadata": {},
    "outputs": [
     {
-     "data": {
-      "text/plain": [
-       "WindowGenerator:\n",
-       "\tTotal window size: 11\n",
-       "\tInput indices: [0 1 2 3 4 5 6 7 8 9]\n",
-       "\tLabel indices: [10]"
-      ]
-     },
-     "execution_count": 32,
-     "metadata": {},
-     "output_type": "execute_result"
+     "ename": "KeyError",
+     "evalue": "'sleep_id'",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
+      "File \u001b[0;32m~/Documents/School Folder/CS 437/Lab/Final Project/venv/lib/python3.10/site-packages/pandas/core/indexes/base.py:3621\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key, method, tolerance)\u001b[0m\n\u001b[1;32m   <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/indexes/base.py?line=3619'>3620</a>\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/indexes/base.py?line=3620'>3621</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_engine\u001b[39m.\u001b[39;49mget_loc(casted_key)\n\u001b[1;32m   <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/indexes/base.py?line=3621'>3622</a>\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m \u001b[39mas\u001b[39;00m err:\n",
+      "File \u001b[0;32mpandas/_libs/index.pyx:136\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
+      "File \u001b[0;32mpandas/_libs/index.pyx:163\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
+      "File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:5198\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
+      "File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:5206\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
+      "\u001b[0;31mKeyError\u001b[0m: 'sleep_id'",
+      "\nThe above exception was the direct cause of the following exception:\n",
+      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
+      "\u001b[1;32m/Users/nowadmin/Documents/School Folder/CS 437/Lab/Final Project/tf_model.ipynb Cell 33'\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/tf_model.ipynb#ch0000072?line=0'>1</a>\u001b[0m wg \u001b[39m=\u001b[39m WindowGenerator(sleep_data)\n\u001b[1;32m      <a href='vscode-notebook-cell:/Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/tf_model.ipynb#ch0000072?line=1'>2</a>\u001b[0m wg\n",
+      "\u001b[1;32m/Users/nowadmin/Documents/School Folder/CS 437/Lab/Final Project/tf_model.ipynb Cell 31'\u001b[0m in \u001b[0;36mWindowGenerator.__init__\u001b[0;34m(self, data, index, input_width, validation_size, test_size, input_feature_slice, label_feature_slice, generate_data_now)\u001b[0m\n\u001b[1;32m      <a href='vscode-notebook-cell:/Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/tf_model.ipynb#ch0000065?line=2'>3</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, data, index: \u001b[39mstr\u001b[39m \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39msleep_id\u001b[39m\u001b[39m\"\u001b[39m, input_width: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m INPUT_TIME_STEP, validation_size: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m VALIDATION_SIZE, test_size: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m TEST_SIZE, input_feature_slice: \u001b[39mslice\u001b[39m \u001b[39m=\u001b[39m \u001b[39mslice\u001b[39m(\u001b[39m1\u001b[39m,\u001b[39m100\u001b[39m), label_feature_slice: \u001b[39mslice\u001b[39m \u001b[39m=\u001b[39m \u001b[39mslice\u001b[39m(\u001b[39m-\u001b[39m\u001b[39m4\u001b[39m,\u001b[39m100\u001b[39m), generate_data_now: \u001b[39mbool\u001b[39m \u001b[39m=\u001b[39m \u001b[39mTrue\u001b[39;00m):\n\u001b[1;32m      <a href='vscode-notebook-cell:/Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/tf_model.ipynb#ch0000065?line=3'>4</a>\u001b[0m     \u001b[39m# Partition data\u001b[39;00m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/tf_model.ipynb#ch0000065?line=4'>5</a>\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtraining, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtesting \u001b[39m=\u001b[39m training_test_split_by_unique_index(sleep_data, index, test_size)\n\u001b[1;32m      <a href='vscode-notebook-cell:/Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/tf_model.ipynb#ch0000065?line=5'>6</a>\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtraining, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mvalidation \u001b[39m=\u001b[39m training_test_split_by_unique_index(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtraining, index, validation_size)\n\u001b[1;32m      <a href='vscode-notebook-cell:/Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/tf_model.ipynb#ch0000065?line=7'>8</a>\u001b[0m     \u001b[39m# Window paramters\u001b[39;00m\n",
+      "\u001b[1;32m/Users/nowadmin/Documents/School Folder/CS 437/Lab/Final Project/tf_model.ipynb Cell 30'\u001b[0m in \u001b[0;36mtraining_test_split_by_unique_index\u001b[0;34m(data, index, test_size)\u001b[0m\n\u001b[1;32m      <a href='vscode-notebook-cell:/Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/tf_model.ipynb#ch0000066?line=0'>1</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mtraining_test_split_by_unique_index\u001b[39m(data, index: \u001b[39mstr\u001b[39m, test_size: \u001b[39mint\u001b[39m \u001b[39m=\u001b[39m \u001b[39m10\u001b[39m):\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/tf_model.ipynb#ch0000066?line=1'>2</a>\u001b[0m     test_ids \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39mrandom\u001b[39m.\u001b[39mchoice(data[index]\u001b[39m.\u001b[39munique(), size \u001b[39m=\u001b[39m test_size, replace\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m)\n\u001b[1;32m      <a href='vscode-notebook-cell:/Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/tf_model.ipynb#ch0000066?line=2'>3</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m data[\u001b[39m~\u001b[39mdata[index]\u001b[39m.\u001b[39misin(test_ids)], data[data[index]\u001b[39m.\u001b[39misin(test_ids)]\n",
+      "File \u001b[0;32m~/Documents/School Folder/CS 437/Lab/Final Project/venv/lib/python3.10/site-packages/pandas/core/frame.py:3505\u001b[0m, in \u001b[0;36mDataFrame.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m   <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/frame.py?line=3502'>3503</a>\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcolumns\u001b[39m.\u001b[39mnlevels \u001b[39m>\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[1;32m   <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/frame.py?line=3503'>3504</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_getitem_multilevel(key)\n\u001b[0;32m-> <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/frame.py?line=3504'>3505</a>\u001b[0m indexer \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcolumns\u001b[39m.\u001b[39;49mget_loc(key)\n\u001b[1;32m   <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/frame.py?line=3505'>3506</a>\u001b[0m \u001b[39mif\u001b[39;00m is_integer(indexer):\n\u001b[1;32m   <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/frame.py?line=3506'>3507</a>\u001b[0m     indexer \u001b[39m=\u001b[39m [indexer]\n",
+      "File \u001b[0;32m~/Documents/School Folder/CS 437/Lab/Final Project/venv/lib/python3.10/site-packages/pandas/core/indexes/base.py:3623\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key, method, tolerance)\u001b[0m\n\u001b[1;32m   <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/indexes/base.py?line=3620'>3621</a>\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_engine\u001b[39m.\u001b[39mget_loc(casted_key)\n\u001b[1;32m   <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/indexes/base.py?line=3621'>3622</a>\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m \u001b[39mas\u001b[39;00m err:\n\u001b[0;32m-> <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/indexes/base.py?line=3622'>3623</a>\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mKeyError\u001b[39;00m(key) \u001b[39mfrom\u001b[39;00m \u001b[39merr\u001b[39;00m\n\u001b[1;32m   <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/indexes/base.py?line=3623'>3624</a>\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mTypeError\u001b[39;00m:\n\u001b[1;32m   <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/indexes/base.py?line=3624'>3625</a>\u001b[0m     \u001b[39m# If we have a listlike key, _check_indexing_error will raise\u001b[39;00m\n\u001b[1;32m   <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/indexes/base.py?line=3625'>3626</a>\u001b[0m     \u001b[39m#  InvalidIndexError. Otherwise we fall through and re-raise\u001b[39;00m\n\u001b[1;32m   <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/indexes/base.py?line=3626'>3627</a>\u001b[0m     \u001b[39m#  the TypeError.\u001b[39;00m\n\u001b[1;32m   <a href='file:///Users/nowadmin/Documents/School%20Folder/CS%20437/Lab/Final%20Project/venv/lib/python3.10/site-packages/pandas/core/indexes/base.py?line=3627'>3628</a>\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_check_indexing_error(key)\n",
+      "\u001b[0;31mKeyError\u001b[0m: 'sleep_id'"
+     ]
     }
    ],
    "source": [
@@ -1043,7 +1093,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 33,
+   "execution_count": 60,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1052,7 +1102,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 34,
+   "execution_count": 61,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1061,26 +1111,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 37,
+   "execution_count": 63,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "(array([[18.,  8., 36.,  1.,  0.,  0.,  0.],\n",
-       "        [19.,  8., 37.,  1.,  0.,  0.,  0.],\n",
-       "        [20.,  8., 38.,  1.,  0.,  0.,  0.],\n",
-       "        [21.,  8., 39.,  1.,  0.,  0.,  0.],\n",
-       "        [22.,  8., 40.,  1.,  0.,  0.,  0.],\n",
-       "        [23.,  8., 41.,  1.,  0.,  0.,  0.],\n",
-       "        [24.,  8., 42.,  1.,  0.,  0.,  0.],\n",
-       "        [25.,  8., 43.,  1.,  0.,  0.,  0.],\n",
-       "        [26.,  8., 44.,  1.,  0.,  0.,  0.],\n",
-       "        [27.,  8., 45.,  1.,  0.,  0.,  0.]], dtype=float32),\n",
-       " array([[0., 0., 1., 0.]], dtype=float32))"
+       "((10, 7), (4,))"
       ]
      },
-     "execution_count": 37,
+     "execution_count": 63,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1090,6 +1130,36 @@
     "sample_array[0][0][INDEX_TIMESTEP], sample_array[0][1][INDEX_TIMESTEP]"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### General Model Helper"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 64,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Adapted from https://www.tensorflow.org/tutorials/structured_data/time_series#linear_model\n",
+    "def compile_and_fit(model, window: WindowGenerator, loss = tf.losses.MeanSquaredError(), optimizer = tf.optimizers.Adam(), metrics = tf.metrics.MeanAbsoluteError(), patience:int = 2, epochs: int = MAX_EPOCHS):\n",
+    "    early_stopping = tf.keras.callbacks.EarlyStopping(\n",
+    "        monitor='val_loss',\n",
+    "        patience=patience,\n",
+    "        mode='min'\n",
+    "    )\n",
+    "\n",
+    "    model.compile(\n",
+    "        loss=loss,\n",
+    "        optimizer=optimizer,\n",
+    "        metrics=metrics,\n",
+    "    )\n",
+    "\n",
+    "    return model.fit(window.training_ds, epochs=epochs, validation_data=window.validation_ds, callbacks=[early_stopping])"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -1099,7 +1169,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 215,
+   "execution_count": 65,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1110,24 +1180,24 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 278,
+   "execution_count": 66,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Model: \"sequential_12\"\n",
+      "Model: \"sequential_7\"\n",
       "_________________________________________________________________\n",
       " Layer (type)                Output Shape              Param #   \n",
       "=================================================================\n",
-      " lstm_14 (LSTM)              (None, 10, 16)            1600      \n",
+      " lstm_8 (LSTM)               (None, 16)                1536      \n",
       "                                                                 \n",
-      " dense_9 (Dense)             (None, 10, 4)             68        \n",
+      " dense_7 (Dense)             (None, 4)                 68        \n",
       "                                                                 \n",
       "=================================================================\n",
-      "Total params: 1,668\n",
-      "Trainable params: 1,668\n",
+      "Total params: 1,604\n",
+      "Trainable params: 1,604\n",
       "Non-trainable params: 0\n",
       "_________________________________________________________________\n",
       "None\n"
@@ -1137,9 +1207,10 @@
    "source": [
     "# Model Definition\n",
     "lstm_model = keras.Sequential()\n",
-    "lstm_model.add(layers.Input(shape=(TIME_STEP_INPUT, 8)))\n",
-    "lstm_model.add(layers.LSTM(LSTM_UNITS, stateful=False, return_sequences=True))\n",
+    "lstm_model.add(layers.Input(shape=(INPUT_TIME_STEP, INPUT_FEATURES_SIZE)))\n",
     "# lstm_model.add(layers.LSTM(LSTM_UNITS, stateful=False, return_sequences=True))\n",
+    "# lstm_model.add(layers.LSTM(LSTM_UNITS, stateful=False, return_sequences=True))\n",
+    "lstm_model.add(layers.LSTM(LSTM_UNITS, stateful=False, return_sequences=False))\n",
     "lstm_model.add(layers.Dense(SLEEP_STAGES))\n",
     "lstm_model.build()\n",
     "print(lstm_model.summary())"
@@ -1147,13 +1218,47 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 67,
    "metadata": {},
    "outputs": [],
    "source": [
     "# Model Training\n",
     "lstm_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
-    "lstm_optm = keras.optimizers.Adam(learning_rate=LSTM_LEARNING_RATE)"
+    "lstm_optm = keras.optimizers.Adam(learning_rate=LSTM_LEARNING_RATE)\n",
+    "lstm_metrics = [tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True), tf.keras.metrics.Accuracy()]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 68,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 1/20\n"
+     ]
+    },
+    {
+     "ename": "",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31mCanceled future for execute_request message before replies were done"
+     ]
+    },
+    {
+     "ename": "",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
+     ]
+    }
+   ],
+   "source": [
+    "lstm_history = compile_and_fit(model=lstm_model, window=wg, loss= lstm_loss, optimizer= lstm_optm, metrics=lstm_metrics)"
    ]
   },
   {
-- 
GitLab