diff --git a/jupyter/compression_insights.ipynb b/jupyter/compression_insights.ipynb index abbd6918c5bf85cfd185c2bb095bdac36c36b380..70a787325eb83c67401e8372b83fd5a2d4596e00 100755 --- a/jupyter/compression_insights.ipynb +++ b/jupyter/compression_insights.ipynb @@ -413,16 +413,119 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ - "params_names = conv_param_names(sparse_model)\n", - "\n", - "def view_kernel_l1(pname):\n", - " tensor = sparse_model.state_dict()[pname]\n", - " plot_l1_norm_hist(tensor)\n", + "def get_norms(weights, structure='kernels'):\n", + " \"\"\"Compute a histogram of the L1-norms of the kernels of a weights tensor.\n", + " \n", + " The L1-norm of a kernel is one way to quantify the \"magnitude\" of the total coeffiecients\n", + " making up this kernel.\n", " \n", - "interact(view_kernel_l1, pname=params_dropdown);" + " Another interesting look at filters is to compute a histogram per filter.\n", + " \"\"\"\n", + " ofms, ifms, kw, kh = weights.size()\n", + " if structure == 'kernels':\n", + " groups = weights.view(ofms * ifms, kh, kw)\n", + " elif structure == 'filters':\n", + " groups = weights\n", + " else:\n", + " raise ValueError('illegal structure')\n", + " \n", + " norms = [[], []]\n", + " groups = groups.view(groups.shape[0], -1)\n", + " group_size = groups.shape[1]\n", + " norms[0] = groups.norm(1, dim=1).div(group_size)\n", + " norms[1] = groups.norm(2, dim=1).div(group_size)\n", + " return norms\n", + "\n", + "def plot_l1_norm_hist(weights, structure):\n", + " norms = get_norms(weights, structure)\n", + " if structure == 'kernels':\n", + " bins = 200\n", + " else:\n", + " bins = 200\n", + " bins = None\n", + " n, bins, patches = plt.hist(norms[0], bins=bins, alpha=0.5)\n", + " plt.title('{} L1-norm histograms'.format(structure))\n", + " plt.ylabel('Frequency')\n", + " plt.xlabel('{} L1-norm'.format(structure))\n", + " plt.show()\n", + " \n", + "def plot_kernels_norm_hist_per_filter(ax, filter, structure, norm_mag):\n", + " norms = get_norms(filter, \"kernels\")\n", + " bins = None\n", + " n, bins, patches = ax.hist(norms[0], alpha=0.5)\n", + " ax.set_xlabel('{:2f} L1-norm'.format(norm_mag))\n", + " \n", + "params_names = conv_param_names(resnet20_dense)\n", + "\n", + "def view_kernel_l1(pname, sort_kernels):\n", + " tensor = resnet20_dense.state_dict()[pname].to('cpu')\n", + " plot_l1_norm_hist(tensor, 'kernels')\n", + " nrows = (tensor.shape[0]+3)//4; ncols = 4\n", + " f, axarr = plt.subplots(nrows, ncols, figsize=(15,7))\n", + " filter_norms = []\n", + " for i in range(0, nrows * ncols):\n", + " filter = tensor[i].unsqueeze(0)\n", + " norm = get_norms(filter, \"filters\")[0][0].item()\n", + " filter_norms.append((norm, i))\n", + " filter_norms.sort(key=lambda norm: norm[0])\n", + " for i in range(0, nrows * ncols):\n", + " filter = tensor[filter_norms[i][1]].unsqueeze(0)\n", + " norm_mag = filter_norms[i][0]\n", + " plot_kernels_norm_hist_per_filter(axarr[i//ncols, i%ncols], filter, 'kernels', norm_mag)\n", + " f.subplots_adjust(hspace=0.6, wspace=0.4)\n", + " plt.show()\n", + "\n", + "sort_choice = widgets.Checkbox(value=True, description='Sort kernels')\n", + "params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, layout=Layout(width='40%'))\n", + "interact(view_kernel_l1, pname=params_dropdown, sort_kernels=sort_choice);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Compare Filter L$_1$ and L$_2$ norms" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "params_names = conv_param_names(resnet20_dense)\n", + "\n", + "def view_weights(pname, sort, draw_l1, draw_l2):\n", + " param = resnet20_dense.state_dict()[pname]\n", + " view_filters = param.view(param.size(0), -1)\n", + " filter_size = view_filters.shape[1]\n", + " filter_mags_l1 = to_np(view_filters.norm(p=1, dim=1).div(filter_size))\n", + " filter_mags_l2 = to_np(view_filters.norm(p=2, dim=1).div(filter_size))\n", + " std_l2, mean_l2 = np.std(filter_mags_l2), np.mean(filter_mags_l2)\n", + " std_l1, mean_l1 = np.std(filter_mags_l1), np.mean(filter_mags_l1)\n", + " if sort:\n", + " filter_mags_l1 = np.sort(filter_mags_l1)\n", + " filter_mags_l2 = np.sort(filter_mags_l2)\n", + " plt.figure(figsize=[15,7.5])\n", + " if draw_l1:\n", + " plt.plot(range(len(filter_mags_l1)), filter_mags_l1, label=\"L1\", marker=\"o\", markersize=5, markerfacecolor=\"C1\")\n", + " if draw_l2:\n", + " plt.plot(range(len(filter_mags_l2)), filter_mags_l2, label=\"L2\", marker=\"+\", markersize=5, markerfacecolor=\"C1\")\n", + " plt.title(\"L1 mean: {:.4f} std: {:.4f}\\nL2 mean: {:.4f} std: {:.4f}\".format(mean_l1, std_l1, mean_l2, std_l2))\n", + " plt.xlabel('Filter index (i.e. output feature-map channel)')\n", + " plt.ylabel('Normalized Fliter L1-norm')\n", + " plt.legend()\n", + "\n", + "sort_choice = widgets.Checkbox(value=True, description='Sort filters')\n", + "l1_choice = widgets.Checkbox(value=True, description='Draw L1')\n", + "l2_choice = widgets.Checkbox(value=True, description='Draw L2')\n", + "params_dropdown = widgets.Dropdown(description='Weights:', options=params_names, layout=Layout(width='40%'))\n", + "interact(view_weights, pname=params_dropdown, sort=sort_choice, draw_l1=l1_choice, draw_l2=l2_choice);\n" ] }, { @@ -553,7 +656,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.2" + "version": "3.6.7" } }, "nbformat": 4,