Commit 22e3ea8b authored by Neta Zmora

Fix activation stats for Linear layers

Thanks to Dan Alistarh for bringing this issue to my attention.
The activations of Linear layers have shape (batch_size, output_size), while those
of Convolution layers have shape (batch_size, num_channels, height, width); this
distinction in shape was not handled correctly.
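
As a quick illustration of the mismatch (a sketch for this message, not code from the commit): per-channel statistics that reshape with `activation.size(2)` and `activation.size(3)` work for 4D convolution outputs but fail on 2D Linear outputs, which is why the statistics functions now branch on `activation.dim()`.

```
import torch
import torch.nn as nn

conv_out   = nn.Conv2d(3, 16, 3)(torch.randn(8, 3, 32, 32))   # shape: (8, 16, 30, 30)
linear_out = nn.Linear(100, 10)(torch.randn(8, 100))          # shape: (8, 10)

view_2d = conv_out.view(-1, conv_out.size(2) * conv_out.size(3))   # fine: (8*16) x (30*30)
# linear_out.size(2)  ->  IndexError: dimension out of range,
# so a 2D (Linear) activation needs its own code path.
```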

This commit also fixes the sparsity computation for very large activations, as seen
in VGG16, where it led to memory exhaustion. One workaround is to use smaller
batch sizes, but this commit instead counts zeros “manually”, which uses far less memory.
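
For a sense of scale (assumed shapes, not taken from this commit): `torch.nonzero` returns an `(N, ndim)` int64 index tensor, so for a large, mostly non-zero activation the indices alone can dwarf the activation itself.

```
# Back-of-the-envelope, assuming a batch of 256 and VGG16's first ReLU
# output of shape (256, 64, 224, 224), with ~50% of the activations non-zero.
numel = 256 * 64 * 224 * 224            # ~8.2e8 elements
nnz = numel // 2                        # assumed non-zero count
index_bytes = nnz * 4 * 8               # one int64 index per dimension, 4 dims
print("index tensor ~%.1f GiB" % (index_bytes / 2**30))   # ~12.2 GiB
# Counting instead needs only a transient bool mask and a scalar:
#   nonzero = tensor.abs().gt(0).sum()
```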

Also in this commit:
- Added a “caveats” section to the documentation.
- Added more tests.
parent fe9ffb17
......@@ -147,14 +147,18 @@ class SummaryActivationStatsCollector(ActivationStatsCollector):
"""
try:
getattr(module, self.stat_name).add(self.summary_fn(output.data))
except RuntimeError:
raise ValueError("ActivationStatsCollector: a module was encountered twice during model.apply().\n"
"This is an indication that your model is using the same module instance, "
"in multiple nodes in the graph. This usually occurs with ReLU modules: \n"
"For example in TorchVision's ResNet model, self.relu = nn.ReLU(inplace=True) is "
"instantiated once, but used multiple times. This is not permissible when using "
"instances of ActivationStatsCollector.")
except RuntimeError as e:
if "The expanded size of the tensor" in e.args[0]:
raise ValueError("ActivationStatsCollector: a module ({} - {}) was encountered twice during model.apply().\n"
"This is an indication that your model is using the same module instance, "
"in multiple nodes in the graph. This usually occurs with ReLU modules: \n"
"For example in TorchVision's ResNet model, self.relu = nn.ReLU(inplace=True) is "
"instantiated once, but used multiple times. This is not permissible when using "
"instances of ActivationStatsCollector.".
format(module.distiller_name, type(module)))
else:
msglogger.info("Exception in _activation_stats_cb: {} {}".format(module.distiller_name, type(module)))
raise
def _start_counter(self, module):
if not hasattr(module, self.stat_name):
......
......@@ -144,10 +144,10 @@ def density(tensor):
Returns:
density (float)
"""
nonzero = torch.nonzero(tensor)
if nonzero.dim() == 0:
return 0.0
return nonzero.size(0) / float(torch.numel(tensor))
# Using torch.nonzero(tensor) can lead to memory exhaustion on
# very large tensors, so we count zeros "manually".
nonzero = tensor.abs().gt(0).sum()
return float(nonzero.item()) / torch.numel(tensor)
def sparsity(tensor):
......@@ -252,14 +252,14 @@ def sparsity_matrix(tensor, dim):
return 1 - nonzero_structs/num_structs
def sparsity_cols(tensor, trasposed=True):
def sparsity_cols(tensor, transposed=True):
"""Column-wise sparsity for 2D tensors
PyTorch GEMM matrices are transposed before they are used in the GEMM operation.
In other words the matrices are stored in memory transposed. So by default we compute
the sparsity of the transposed dimension.
"""
if trasposed:
if transposed:
return sparsity_matrix(tensor, 0)
return sparsity_matrix(tensor, 1)
......@@ -269,14 +269,14 @@ def density_cols(tensor, transposed=True):
return 1 - sparsity_cols(tensor, transposed)
def sparsity_rows(tensor, trasposed=True):
def sparsity_rows(tensor, transposed=True):
"""Row-wise sparsity for 2D matrices
PyTorch GEMM matrices are transposed before they are used in the GEMM operation.
In other words the matrices are stored in memory transposed. So by default we compute
the sparsity of the transposed dimension.
"""
if trasposed:
if transposed:
return sparsity_matrix(tensor, 1)
return sparsity_matrix(tensor, 0)
......@@ -339,9 +339,14 @@ def activation_channels_l1(activation):
Returns - for each channel: the batch-mean of its L1 magnitudes (i.e. over all of the
activations in the mini-batch, compute the mean of the L1 magnitude of each channel).
"""
view_2d = activation.view(-1, activation.size(2) * activation.size(3)) # (batch*channel) x (h*w)
featuremap_norms = view_2d.norm(p=1, dim=1)
featuremap_norms_mat = featuremap_norms.view(activation.size(0), activation.size(1)) # batch x channel
if activation.dim() == 4:
view_2d = activation.view(-1, activation.size(2) * activation.size(3)) # (batch*channels) x (h*w)
featuremap_norms = view_2d.norm(p=1, dim=1) # (batch*channels) x 1
featuremap_norms_mat = featuremap_norms.view(activation.size(0), activation.size(1)) # batch x channels
elif activation.dim() == 2:
featuremap_norms_mat = activation.norm(p=1, dim=1) # batch x 1
else:
raise ValueError("activation_channels_l1: Unsupported shape: ".format(activation.shape))
# We need to move the results back to the CPU
return featuremap_norms_mat.mean(dim=0).cpu()
......@@ -357,9 +362,14 @@ def activation_channels_means(activation):
Returns - for each channel: the batch-mean of its mean values (i.e. over all of the
activations in the mini-batch, compute the mean value of each channel).
"""
view_2d = activation.view(-1, activation.size(2) * activation.size(3)) # (batch*channel) x (h*w)
featuremap_means = sparsity_rows(view_2d)
featuremap_means_mat = featuremap_means.view(activation.size(0), activation.size(1)) # batch x channel
if activation.dim() == 4:
view_2d = activation.view(-1, activation.size(2) * activation.size(3)) # (batch*channels) x (h*w)
featuremap_means = view_2d.mean(dim=1) # (batch*channels) x 1
featuremap_means_mat = featuremap_means.view(activation.size(0), activation.size(1)) # batch x channels
elif activation.dim() == 2:
featuremap_means_mat = activation.mean(dim=1) # batch x 1
else:
raise ValueError("activation_channels_means: Unsupported shape: ".format(activation.shape))
# We need to move the results back to the CPU
return featuremap_means_mat.mean(dim=0).cpu()
......@@ -377,11 +387,15 @@ def activation_channels_apoz(activation):
Returns - for each channel: the batch-mean of its sparsity.
"""
view_2d = activation.view(-1, activation.size(2) * activation.size(3)) # (batch*channel) x (h*w)
featuremap_means = view_2d.mean(dim=1) # global average pooling
featuremap_means_mat = featuremap_means.view(activation.size(0), activation.size(1)) # batch x channel
# We need to move the results back to the CPU
return featuremap_means_mat.mean(dim=0).cpu()
if activation.dim() == 4:
view_2d = activation.view(-1, activation.size(2) * activation.size(3)) # (batch*channels) x (h*w)
featuremap_apoz = view_2d.abs().gt(0).sum(dim=1).float() / (activation.size(2) * activation.size(3)) # (batch*channels) x 1
featuremap_apoz_mat = featuremap_apoz.view(activation.size(0), activation.size(1)) # batch x channels
elif activation.dim() == 2:
featuremap_apoz_mat = activation.abs().gt(0).sum(dim=1).float() / activation.size(1) # batch x 1
else:
raise ValueError("activation_channels_apoz: Unsupported shape: ".format(activation.shape))
return featuremap_apoz_mat.mean(dim=0).cpu()
def log_training_progress(stats_dict, params_dict, epoch, steps_completed, total_steps, log_freq, loggers):
......
......@@ -225,7 +225,7 @@ $ tensorboard --logdir=logs
Distiller's setup (requirements.txt) installs TensorFlow for CPU. If you want a different installation, please follow the [TensorFlow installation instructions](https://www.tensorflow.org/install/install_linux).
## Collecting feature-maps statistics
## Collecting activations statistics
In CNNs with ReLU layers, ReLU activations (feature-maps) also exhibit a nice level of sparsity (50-60% sparsity is typical). <br>
You can collect activation statistics using the ```--act_stats``` command-line flag.<br>
For example:
......@@ -258,6 +258,96 @@ You can use a utility function, ```distiller.log_activation_statsitics```, to lo
distiller.log_activation_statsitics(epoch, "train", loggers=[tflogger],
collector=collectors["sparsity"])
```
### Caveats
Distiller collects activation statistics using PyTorch's forward-hook mechanism. A collector iterates over the model's modules and registers a forward hook on each one; during the forward pass these hooks are invoked and given access to the activation data. Registering a forward callback is performed like this:
```
module.register_forward_hook
```
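For illustration, a minimal sketch of this pattern (hypothetical helper names, not Distiller's collector classes): a hook computes a statistic from a module's output, and ```model.apply``` is used to register it on the modules of interest.
```
import torch.nn as nn

def sparsity_hook(module, input, output):
    # Fraction of zero-valued activations produced by this module
    zeros = output.numel() - output.abs().gt(0).sum().item()
    module.sparsity_log = getattr(module, 'sparsity_log', []) + [zeros / output.numel()]

def attach_hooks(model):
    handles = []
    def register(module):
        if isinstance(module, nn.ReLU):            # collect only on ReLU outputs
            handles.append(module.register_forward_hook(sparsity_hook))
    model.apply(register)
    return handles                                 # call handle.remove() when collection is done
```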
This makes apparent two limitations of this mechanism:
1. We can only register on PyTorch modules. This means that we can't register a forward hook on functionals such as ```torch.nn.functional.relu``` and ```torch.nn.functional.max_pool2d```.
Therefore, you may need to replace functionals with their module alternatives. For example:
```
class MadeUpNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return x
```
This can be changed to:
```
class MadeUpNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.relu(self.conv1(x))
return x
```
2. We can only use a module instance once in our models. If we use the same module several times, then we can't determine which node in the graph has invoked the callback, because the PyTorch callback signature ```def hook(module, input, output)``` doesn't provide enough contextual information.
TorchVision's [ResNet](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py) is an example of a model that uses the same instance of nn.ReLU multiple times:
```
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out) # <================
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out) # <================
return out
```
In Distiller we changed [ResNet](https://github.com/NervanaSystems/distiller/blob/master/models/imagenet/resnet.py) to use multiple instances of nn.ReLU, and each instance is used only once:
```
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.relu2 = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out) # <================
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu2(out) # <================
return out
```
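A toy sketch (hypothetical module, not from Distiller) of why instance reuse breaks per-node attribution: the shared ReLU's hook fires for both graph nodes, so their activations are merged into a single module's statistics.
```
import torch
import torch.nn as nn

class Reuser(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 8)
        self.relu = nn.ReLU()              # one instance, used twice in forward()

    def forward(self, x):
        return self.relu(self.fc2(self.relu(self.fc1(x))))

calls = {}
def count_hook(module, input, output):
    calls[module] = calls.get(module, 0) + 1

model = Reuser()
model.relu.register_forward_hook(count_hook)
model(torch.randn(4, 8))
print(calls[model.relu])   # 2 -- both graph nodes are attributed to the same module
```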
# Using the Jupyter notebooks
The Jupyter notebooks contain many examples of how to use the statistics summaries generated by Distiller. They are explained in a separate page.
......
......@@ -258,5 +258,5 @@ And of course, if we used a sparse or compressed representation, then we are red
<!--
MkDocs version : 0.17.2
Build Date UTC : 2018-11-21 21:54:00
Build Date UTC : 2018-11-24 09:47:02
-->
......@@ -4,7 +4,7 @@
<url>
<loc>/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -12,7 +12,7 @@
<url>
<loc>/install/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -20,7 +20,7 @@
<url>
<loc>/usage/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -28,7 +28,7 @@
<url>
<loc>/schedule/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -37,31 +37,31 @@
<url>
<loc>/pruning/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>/regularization/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>/quantization/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>/knowledge_distillation/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>/conditional_computation/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -71,19 +71,19 @@
<url>
<loc>/algo_pruning/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>/algo_quantization/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
<url>
<loc>/algo_earlyexit/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -92,7 +92,7 @@
<url>
<loc>/model_zoo/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -100,7 +100,7 @@
<url>
<loc>/jupyter/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
......@@ -108,7 +108,7 @@
<url>
<loc>/design/index.html</loc>
<lastmod>2018-11-21</lastmod>
<lastmod>2018-11-24</lastmod>
<changefreq>daily</changefreq>
</url>
......
......@@ -81,7 +81,7 @@
<li><a class="toctree-l3" href="#using-tensorboard">Using TensorBoard</a></li>
<li><a class="toctree-l3" href="#collecting-feature-maps-statistics">Collecting feature-maps statistics</a></li>
<li><a class="toctree-l3" href="#collecting-activations-statistics">Collecting activations statistics</a></li>
</ul>
......@@ -412,7 +412,7 @@ To view the graphs, invoke the TensorBoard server. For example:</p>
</code></pre>
<p>Distiller's setup (requirements.txt) installs TensorFlow for CPU. If you want a different installation, please follow the <a href="https://www.tensorflow.org/install/install_linux">TensorFlow installation instructions</a>.</p>
<h2 id="collecting-feature-maps-statistics">Collecting feature-maps statistics</h2>
<h2 id="collecting-activations-statistics">Collecting activations statistics</h2>
<p>In CNNs with ReLU layers, ReLU activations (feature-maps) also exhibit a nice level of sparsity (50-60% sparsity is typical). <br>
You can collect activation statistics using the <code>--act_stats</code> command-line flag.<br>
For example:</p>
......@@ -443,6 +443,96 @@ You can use a utility function, <code>distiller.log_activation_statsitics</code>
collector=collectors[&quot;sparsity&quot;])
</code></pre>
<h3 id="caveats">Caveats</h3>
<p>Distiller collects activation statistics using PyTorch's forward-hook mechanism. A collector iterates over the model's modules and registers a forward hook on each one; during the forward pass these hooks are invoked and given access to the activation data. Registering a forward callback is performed like this:</p>
<pre><code>module.register_forward_hook
</code></pre>
<p>This makes apparent two limitations of this mechanism:</p>
<ol>
<li>We can only register on PyTorch modules. This means that we can't register a forward hook on functionals such as <code>torch.nn.functional.relu</code> and <code>torch.nn.functional.max_pool2d</code>.<br />
Therefore, you may need to replace functionals with their module alternatives. For example: </li>
</ol>
<pre><code>class MadeUpNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return x
</code></pre>
<p>This can be changed to: </p>
<pre><code>class MadeUpNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.relu(self.conv1(x))
return x
</code></pre>
<ol>
<li>We can only use a module instance once in our models. If we use the same module several times, then we can't determine which node in the graph has invoked the callback, because the PyTorch callback signature <code>def hook(module, input, output)</code> doesn't provide enough contextual information.<br />
TorchVision's <a href="https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py">ResNet</a> is an example of a model that uses the same instance of nn.ReLU multiple times: </li>
</ol>
<pre><code>class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out) # &lt;================
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out) # &lt;================
return out
</code></pre>
<p>In Distiller we changed <a href="https://github.com/NervanaSystems/distiller/blob/master/models/imagenet/resnet.py">ResNet</a> to use multiple instances of nn.ReLU, and each instance is used only once: </p>
<pre><code>class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.relu2 = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out) # &lt;================
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu2(out) # &lt;================
return out
</code></pre>
<h1 id="using-the-jupyter-notebooks">Using the Jupyter notebooks</h1>
<p>The Jupyter notebooks contain many examples of how to use the statistics summaries generated by Distiller. They are explained in a separate page.</p>
<h1 id="generating-this-documentation">Generating this documentation</h1>
......
......@@ -17,23 +17,58 @@
import torch
import os
import sys
import common
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
sys.path.append(module_path)
import distiller
import models
def test_sparsity():
zeros = torch.zeros(2,3,5,6)
zeros = torch.zeros(2, 3, 5, 6)
print(distiller.sparsity(zeros))
assert distiller.sparsity(zeros) == 1.0
assert distiller.sparsity_3D(zeros) == 1.0
assert distiller.density_3D(zeros) == 0.0
ones = torch.ones(12, 43, 4, 6)
assert distiller.sparsity(ones) == 0.0
x = torch.tensor([[1., 2., 0, 4., 0],
[1., 2., 0, 4., 0]])
assert distiller.density(x) == 0.6
assert distiller.density_cols(x, transposed=False) == 0.6
assert distiller.sparsity_rows(x, transposed=False) == 0
x = torch.tensor([[0., 0., 0],
[1., 4., 0],
[1., 2., 0],
[0., 0., 0]])
assert distiller.density(x) == 4/12
assert distiller.sparsity_rows(x, transposed=False) == 0.5
assert common.almost_equal(distiller.sparsity_cols(x, transposed=False), 1/3)
assert common.almost_equal(distiller.sparsity_rows(x), 1/3)
ones = torch.zeros(12,43,4,6)
ones.fill_(1)
assert distiller.sparsity(ones) == 0.0
def test_activations():
x = torch.tensor([[[[1., 0., 0.],
[0., 2., 0.],
[0., 0., 3.]],
[[1., 0., 2.],
[0., 3., 0.],
[4., 0., 5.]]],
[[[4., 0., 0.],
[0., 5., 0.],
[0., 0., 6.]],
[[0., 6., 0.],
[7., 0., 8.],
[0., 9., 0.]]]])
assert all(distiller.activation_channels_l1(x) == torch.tensor([21/2, 45/2]))
assert all(distiller.activation_channels_apoz(x) == torch.tensor([6/18, 9/18]))
assert all(distiller.activation_channels_means(x) == torch.tensor([21/18, 45/18]))
def test_utils():
model = models.create_model(False, 'cifar10', 'resnet20_cifar', parallel=False)
......