{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial: Adding a new image-classification model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import distiller\n",
    "import torch.nn as nn\n",
    "from distiller.models import register_user_model\n",
    "import distiller.apputils.image_classifier as classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModel(nn.Module):\n",
    "    \"\"\"Small two-stage convolutional classifier used in this tutorial.\n",
    "\n",
    "    Layout: (conv -> ReLU -> max-pool) twice, then average pooling and a\n",
    "    single linear layer with 10 outputs. Every operation is its own named\n",
    "    sub-module.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(MyModel, self).__init__()\n",
    "        # Stage 1: 1 input channel -> 20 feature maps.\n",
    "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)\n",
    "        self.relu1 = nn.ReLU(inplace=False)\n",
    "        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        # Stage 2: 20 -> 50 feature maps.\n",
    "        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)\n",
    "        self.relu2 = nn.ReLU(inplace=False)\n",
    "        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        # Head: 4x4 average pooling, then a 50 -> 10 linear layer.\n",
    "        self.avgpool = nn.AvgPool2d(kernel_size=4, stride=1)\n",
    "        self.fc = nn.Linear(in_features=50, out_features=10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run a batch through the network and return the output of ``fc``.\n",
    "\n",
    "        ``x`` needs a single channel (``conv1`` has ``in_channels=1``), and the\n",
    "        flattened features must have 50 elements per sample to match ``fc``;\n",
    "        this holds for 28x28 inputs.\n",
    "        \"\"\"\n",
    "        out = self.conv1(x)\n",
    "        out = self.relu1(out)\n",
    "        out = self.pool1(out)\n",
    "\n",
    "        out = self.conv2(out)\n",
    "        out = self.relu2(out)\n",
    "        out = self.pool2(out)\n",
    "\n",
    "        out = self.avgpool(out)\n",
    "        # Flatten everything except the batch dimension before the linear layer.\n",
    "        batch_size = out.size(0)\n",
    "        out = out.view(batch_size, -1)\n",
    "        return self.fc(out)\n",
    "\n",
    "def my_model():\n",
    "    \"\"\"Factory: build and return a newly constructed ``MyModel`` instance.\"\"\"\n",
    "    model = MyModel()\n",
    "    return model\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'params_nnz_cnt': -26000.0, 'sparsity': 0.0, 'top1': 92.80000000000001, 'top5': 99.63333333333333, 'epoch': 0}]\n"
     ]
    }
   ],
   "source": [
    "# Register the factory under the (arch, dataset) pair that is requested below.\n",
    "register_user_model(arch=\"MyModel\", dataset=\"mnist\", model=my_model)\n",
    "\n",
    "# NOTE(review): pretrained=True is requested although this notebook defines no\n",
    "# pretrained weights for MyModel -- confirm this is the intended value.\n",
    "model = distiller.models.create_model(arch=\"MyModel\", dataset=\"mnist\", pretrained=True)\n",
    "assert model is not None\n",
    "\n",
    "\n",
    "def init_jupyter_default_args(args):\n",
    "    \"\"\"Assign this notebook's baseline settings as attributes on ``args``.\n",
    "\n",
    "    Every attribute listed below is set unconditionally, replacing any value\n",
    "    already present. ``args`` is modified in place and nothing is returned.\n",
    "    \"\"\"\n",
    "    # The table is rebuilt on every call, so ``activation_stats`` always\n",
    "    # receives a fresh list object.\n",
    "    baseline = (\n",
    "        (\"output_dir\", None),\n",
    "        (\"evaluate\", False),\n",
    "        (\"seed\", None),\n",
    "        (\"deterministic\", False),\n",
    "        (\"cpu\", True),\n",
    "        (\"gpus\", None),\n",
    "        (\"load_serialized\", False),\n",
    "        (\"deprecated_resume\", None),\n",
    "        (\"resumed_checkpoint_path\", None),\n",
    "        (\"load_model_path\", None),\n",
    "        (\"reset_optimizer\", False),\n",
    "        (\"lr\", 0.),\n",
    "        (\"momentum\", 0.),\n",
    "        (\"weight_decay\", 0.),\n",
    "        (\"compress\", None),\n",
    "        (\"epochs\", 0),\n",
    "        (\"activation_stats\", list()),\n",
    "        (\"batch_size\", 1),\n",
    "        (\"workers\", 1),\n",
    "        (\"validation_split\", 0.1),\n",
    "        (\"effective_train_size\", 1.),\n",
    "        (\"effective_valid_size\", 1.),\n",
    "        (\"effective_test_size\", 1.),\n",
    "        (\"log_params_histograms\", False),\n",
    "        (\"print_freq\", 1),\n",
    "        (\"masks_sparsity\", False),\n",
    "        (\"display_confusion\", False),\n",
    "        (\"num_best_scores\", 1),\n",
    "        (\"name\", \"\"),\n",
    "    )\n",
    "    for attr_name, attr_value in baseline:\n",
    "        setattr(args, attr_name, attr_value)\n",
    "\n",
    "\n",
    "def config_learner_args(args, arch, dataset, dataset_path, pretrained, sgd_args, batch, epochs):\n",
    "    \"\"\"Store the learner configuration as attributes on ``args``.\n",
    "\n",
    "    Args:\n",
    "        args: object that receives the settings; modified in place.\n",
    "        arch: architecture name, stored as ``args.arch``.\n",
    "        dataset: dataset name, stored as ``args.dataset``.\n",
    "        dataset_path: dataset location, stored as ``args.data``.\n",
    "        pretrained: stored as ``args.pretrained``.\n",
    "        sgd_args: indexable whose first three items are the learning rate,\n",
    "            momentum and weight decay, in that order.\n",
    "        batch: stored as ``args.batch_size``.\n",
    "        epochs: stored as ``args.epochs``.\n",
    "\n",
    "    Returns:\n",
    "        None.\n",
    "    \"\"\"\n",
    "    # Use the caller's values; these were previously hard-coded, which made\n",
    "    # the corresponding parameters silently ineffective.\n",
    "    args.arch = arch\n",
    "    args.dataset = dataset\n",
    "    args.data = dataset_path\n",
    "    args.pretrained = pretrained\n",
    "    args.lr = sgd_args[0]\n",
    "    args.momentum = sgd_args[1]\n",
    "    args.weight_decay = sgd_args[2]\n",
    "    args.batch_size = batch\n",
    "    args.epochs = epochs\n",
    "\n",
    "# Fill the args object in two passes: notebook-wide defaults first, then the\n",
    "# settings specific to this training run.\n",
    "args = classifier.init_classifier_compression_arg_parser()\n",
    "init_jupyter_default_args(args)\n",
    "config_learner_args(\n",
    "    args,\n",
    "    arch=\"MyModel\",\n",
    "    dataset=\"mnist\",\n",
    "    dataset_path=\"/datasets/mnist/\",\n",
    "    pretrained=False,\n",
    "    sgd_args=(0.1, 0.9, 1e-4),  # (lr, momentum, weight_decay)\n",
    "    batch=256,\n",
    "    epochs=1,\n",
    ")\n",
    "\n",
    "# NOTE(review): os.path.dirname(\".\") evaluates to the empty string; confirm\n",
    "# that an empty script_dir is what ClassifierCompressor should receive.\n",
    "app = classifier.ClassifierCompressor(args, script_dir=os.path.dirname(\".\"))\n",
    "\n",
    "# Run the training loop\n",
    "perf_scores_history = app.run_training_loop()\n",
    "print(perf_scores_history)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}