From 4c7d4890eb27c56bb6061ed4c859f48b679b0e49 Mon Sep 17 00:00:00 2001
From: Lev Zlotnik <46742999+levzlotnik@users.noreply.github.com>
Date: Sun, 23 Jun 2019 13:01:08 +0300
Subject: [PATCH] Simulated BN folding during training (module impl only)
 (#274)

Implementation of simulated BatchNorm folding as per
https://arxiv.org/pdf/1806.08342.pdf

* NOTE: This is just the module implementation for now, not yet integrated
  with QuantAwareTrainRangeLinearQuantizer
---
 distiller/models/imagenet/mobilenet.py |   0
 distiller/quantization/sim_bn_fold.py  | 211 +++++++++++++++++++++++++
 tests/test_sim_bn_fold.py              | 113 +++++++++++++
 3 files changed, 324 insertions(+)
 mode change 100755 => 100644 distiller/models/imagenet/mobilenet.py
 create mode 100644 distiller/quantization/sim_bn_fold.py
 create mode 100644 tests/test_sim_bn_fold.py

diff --git a/distiller/models/imagenet/mobilenet.py b/distiller/models/imagenet/mobilenet.py
old mode 100755
new mode 100644
diff --git a/distiller/quantization/sim_bn_fold.py b/distiller/quantization/sim_bn_fold.py
new file mode 100644
index 0000000..8891324
--- /dev/null
+++ b/distiller/quantization/sim_bn_fold.py
@@ -0,0 +1,211 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+__all__ = ['SimulatedFoldedBatchNorm']
+
+# Number of steps before freezing the batch norm running average and variance
+FREEZE_BN_DELAY_DEFAULT = 200000
+
+
+def _broadcast_correction_factor(c, broadcast_to_shape):
+    """
+    Returns a view of `c` which is broadcastable with shape `broadcast_to_shape`.
+    """
+    filler_dims = (1,) * (len(broadcast_to_shape) - len(c.shape) - 1)
+    view_dims = (*c.shape, *filler_dims)
+    return c.view(view_dims)
+
+
+class SimulatedFoldedBatchNorm(nn.Module):
+    def __init__(self, param_module, bn, freeze_bn_delay=FREEZE_BN_DELAY_DEFAULT, param_quantization_fn=None):
+        """
+        Wrapper for simulated folding of BatchNorm into convolution / linear layers during training
+        Args:
+            param_module (nn.Linear or nn.Conv2d): the wrapped parameter layer
+            bn (nn.BatchNorm1d or nn.BatchNorm2d): batch normalization
+            freeze_bn_delay (int): number of steps before freezing the batchnorm running stats
+            param_quantization_fn (function): function to be used for weight/bias quantization
+        Note:
+            The quantized version was implemented according to https://arxiv.org/pdf/1806.08342.pdf Section 3.2.2.
+        """
+        SimulatedFoldedBatchNorm.verify_module_types(param_module, bn)
+        if not bn.track_running_stats or not bn.affine:
+            raise ValueError("Simulated BN folding is only supported for BatchNorm which tracks running stats with"
+                             "affine weights.")
+        super(SimulatedFoldedBatchNorm, self).__init__()
+        self.param_module = param_module
+        self.bn = bn
+        self.freeze_bn_delay = freeze_bn_delay
+        self.frozen = False
+        self._has_bias = (self.param_module.bias is not None)
+        self.param_quant_fn = param_quantization_fn
+        if isinstance(param_module, nn.Linear):
+            self.param_forward_fn = self._linear_layer_forward
+            self.param_module_type = "fc"
+        else:
+            self.param_forward_fn = self._conv2d_layer_forward
+            self.param_module_type = "conv2d"
+
+    @staticmethod
+    def verify_module_types(param_module, bn):
+        if not isinstance(param_module, (nn.Linear, nn.Conv2d)) \
+                and not isinstance(bn, (nn.BatchNorm1d, nn.BatchNorm2d)):
+            raise TypeError("Only supporting fusing nn.BatchNorm1d/nn.BatchNorm2d into nn.Linear/nn.Conv2d.")
+        if isinstance(param_module, nn.Linear) and isinstance(bn, nn.BatchNorm2d):
+            raise TypeError("nn.Linear layer has to be followed by a nn.BatchNorm1d layer.")
+        if isinstance(param_module, nn.Conv2d) and isinstance(bn, nn.BatchNorm1d):
+            raise TypeError("nn.Con2d layer has to be followed by a nn.BatchNorm2d layer.")
+
+    def forward(self, x):
+        """
+        According to https://arxiv.org/pdf/1806.08342.pdf section 3.2.2.
+        Note:
+            The param layer bias doesn't get included in the calculation!
+            When calculating the batch norm,
+            the bias offsets the mean and so when calculating (x - mu) we get the unbiased position
+            w.r.t. to the mean.
+            i.e. the result of the forward is:
+            bn(param(x)) = ( param(x) - E(param(x)) ) * gamma / std(param(x)) + beta =
+                          = ( x*W + B - E(x*W +B) ) * gamma / sqrt(E((x*W+ B - E(x*W +B))^2)) + beta =
+                          = (x*W -E(x*W)) * gamma / std(x*W) + beta
+        """
+        if not self.frozen:
+            w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias
+            if self.training:
+                batch_mean, batch_var = self.batch_stats(self.param_forward_fn(x, w), b)
+                recip_sigma_batch = torch.rsqrt(batch_var + self.bn.eps)
+                with torch.no_grad():
+                    sigma_running = torch.sqrt(self.bn.running_var + self.bn.eps)
+                w_corrected = w * self.broadcast_correction_weight(gamma / sigma_running)
+                w_quantized = self._quant_param(w_corrected)
+                recip_c = self.broadcast_correction(sigma_running * recip_sigma_batch)
+                bias_corrected = beta - gamma * batch_mean * recip_sigma_batch
+                bias_quantized = self.broadcast_correction(self._quant_param(bias_corrected))
+                y = self.param_forward_fn(x, w_quantized, None)
+                y.mul_(recip_c).add_(bias_quantized)
+            else:
+                with torch.no_grad():
+                    recip_sigma_running = torch.rsqrt(self.bn.running_var + self.bn.eps)
+                w_corrected = w * self.broadcast_correction_weight(gamma * recip_sigma_running)
+                w_quantized = self._quant_param(w_corrected)
+                corrected_mean = self.bn.running_mean - (b if b is not None else 0)
+                bias_corrected = beta - gamma * corrected_mean * recip_sigma_running
+                bias_quantized = self._quant_param(bias_corrected)
+                y = self.param_forward_fn(x, w_quantized, bias_quantized)
+        else:
+            w, b = self.param_module.weight, self.param_module.bias
+            w_quantized, bias_quantized = self._quant_param(w), self._quant_param(b)
+            y = self.param_forward_fn(x, w_quantized, bias_quantized)
+
+        return y
+
+    def broadcast_correction(self, c: torch.Tensor):
+        """
+        Broadcasts a correction factor to the output for elementwise operations.
+        """
+        expected_output_dim = 2 if self.param_module_type == "fc" else 4
+        view_fillers_dim = expected_output_dim - c.dim() - 1
+        view_filler = (1,) * view_fillers_dim
+        expected_view_shape = c.shape + view_filler
+        return c.view(*expected_view_shape)
+
+    def broadcast_correction_weight(self, c: torch.Tensor):
+        """
+        Broadcasts a correction factor to the weight.
+        """
+        if c.dim() != 1:
+            raise ValueError("Correction factor needs to have a single dimension")
+        expected_weight_dim = 2 if self.param_module_type == "fc" else 4
+        view_fillers_dim = expected_weight_dim - c.dim()
+        view_filler = (1,) * view_fillers_dim
+        expected_view_shape = c.shape + view_filler
+        return c.view(*expected_view_shape)
+
+    def _quant_param(self, t: torch.Tensor):
+        """
+        Quantize a parameter locally.
+        """
+        if t is None or self.param_quant_fn is None:
+            return t
+        return self.param_quant_fn(t)
+
+    def batch_stats(self, x, bias=None):
+        """
+        Get the batch mean and variance of x and updates the BatchNorm's running mean and average.
+        Args:
+            x (torch.Tensor): input batch.
+            bias (torch.Tensor): the bias that is to be applied to the batch.
+        Returns:
+            (mean,variance)
+        Note:
+            In case of `nn.Linear`, x may be of shape (N, C, L) or (N, L)
+            where N is batch size, C is number of channels, L is the features size.
+            The batch norm computes the stats over C in the first case or L on the second case.
+            The batch normalization layer is
+            (`nn.BatchNorm1d`)[https://pytorch.org/docs/stable/nn.html#batchnorm1d]
+
+            In case of `nn.Conv2d`, x is of shape (N, C, H, W)
+            where H,W are the image dimensions, and the batch norm computes the stats over C.
+            The batch normalization layer is
+            (`nn.BatchNorm2d`)[https://pytorch.org/docs/stable/nn.html#batchnorm2d]
+        """
+        channel_size = self.bn.num_features
+        self.bn.num_batches_tracked += 1
+
+        # Calculate current batch stats
+        batch_mean = x.transpose(0, 1).contiguous().view(channel_size, -1).mean(1)
+        # BatchNorm currently uses biased variance (without Bessel's correction) as was discussed at
+        # https://github.com/pytorch/pytorch/issues/1410
+        #
+        # also see the source code itself:
+        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L216
+        batch_var = x.transpose(0, 1).contiguous().view(channel_size, -1).var(1, unbiased=False)
+
+        # Update running stats
+        with torch.no_grad():
+            biased_batch_mean = batch_mean + (bias if bias is not None else 0)
+            # However - running_var is updated using unbiased variance!
+            # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Normalization.cpp#L223
+            n = x.numel() / channel_size
+            corrected_var = batch_var * (n / (n - 1))
+            momentum = self.bn.momentum
+            if momentum is None:
+                # momentum is None - we compute a cumulative moving average
+                # as noted in https://pytorch.org/docs/stable/nn.html#batchnorm2d
+                momentum = 1. / float(self.bn.num_batches_tracked)
+            self.bn.running_mean.mul_(1 - momentum).add_(momentum * biased_batch_mean)
+            self.bn.running_var.mul_(1 - momentum).add_(momentum * corrected_var)
+
+        if self.bn.num_batches_tracked > self.freeze_bn_delay:
+            self.freeze()
+
+        return batch_mean, batch_var
+
+    def _linear_layer_forward(self, input, w, b=None):
+        return F.linear(input, w, b)
+
+    def _conv2d_layer_forward(self, input, w, b=None):
+        # We copy the code from the Conv2d forward, but plug in our weights.
+        conv = self.param_module  # type: nn.Conv2d
+        if conv.__dict__.get('padding_mode', None) == 'circular':  # This attribute doesn't exist yet in pytorch 1.0.1
+            expanded_padding = [(conv.padding[1] + 1) // 2, conv.padding[1] // 2,
+                                (conv.padding[0] + 1) // 2, conv.padding[0] // 2]
+            return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
+                            w, b, conv.stride,
+                            (0, 0), conv.dilation, conv.groups)
+        return F.conv2d(input, w, b, conv.stride,
+                        conv.padding, conv.dilation, conv.groups)
+
+    def freeze(self):
+        w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias
+        with torch.no_grad():
+            recip_sigma_running = torch.rsqrt(self.bn.running_var + self.bn.eps)
+            w.mul_(self.broadcast_correction_weight(gamma * recip_sigma_running))
+            corrected_mean = self.bn.running_mean - (b if b is not None else 0)
+            bias_corrected = beta - gamma * corrected_mean * recip_sigma_running
+            if b is not None:
+                b.copy_(bias_corrected)
+            else:
+                self.param_module.bias = nn.Parameter(bias_corrected)
+        self.frozen = True
diff --git a/tests/test_sim_bn_fold.py b/tests/test_sim_bn_fold.py
new file mode 100644
index 0000000..a4c9934
--- /dev/null
+++ b/tests/test_sim_bn_fold.py
@@ -0,0 +1,113 @@
+import distiller
+import torch
+from torch.testing import assert_allclose
+import torch.nn as nn
+from distiller.quantization.sim_bn_fold import SimulatedFoldedBatchNorm
+from copy import deepcopy
+import pytest
+
+ATOL = 5e-5
+RTOL = 1e-3
+BATCH_SIZE = 32
+LR = 1e-3
+
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+@pytest.fixture(params=[False, True], ids=['bias_off', 'bias_on'])
+def has_bias(request):
+    return request.param
+
+
+@pytest.fixture(params=[0.1, None], ids=['ema', 'cma'])
+def momentum(request):
+    return request.param
+
+
+@pytest.mark.parametrize(
+    "batch_size, input_size, output_size",
+    [
+        (2, 2, 3),
+        (32, 512, 1024),
+        (256, 128, 1024)
+    ]
+)
+def test_simulated_bn_fold_fc(has_bias, batch_size, input_size, output_size, momentum):
+    distiller.set_deterministic(1234)
+    linear = nn.Linear(input_size, output_size, bias=has_bias)
+    bn = nn.BatchNorm1d(output_size, momentum=momentum)
+    run_simulated_bn_fold_test(linear, bn, (batch_size, input_size), has_bias)
+        
+
+@pytest.mark.parametrize(
+    "batch_size, input_c, output_c, h, w, kernel_size",
+    [
+        (2, 2, 3, 224, 224, 3),
+        (32, 3, 64, 224, 224, 3),
+        (256, 3, 64, 28, 28, 7),
+    ]
+)
+def test_simulated_bn_fold_conv(has_bias, batch_size, input_c, output_c, h, w, kernel_size, momentum):
+    distiller.set_deterministic(1234)
+    conv2d = nn.Conv2d(input_c, output_c, kernel_size, bias=has_bias)
+    bn = nn.BatchNorm2d(output_c, momentum=momentum)
+    run_simulated_bn_fold_test(conv2d, bn, (batch_size, input_c, h, w), has_bias)
+
+
+def run_simulated_bn_fold_test(param_layer, bn_layer, x_size, has_bias):
+    folded = SimulatedFoldedBatchNorm(deepcopy(param_layer), deepcopy(bn_layer), param_quantization_fn=None)
+    unfolded = nn.Sequential(param_layer, bn_layer)
+    folded, unfolded = folded.to(DEVICE), unfolded.to(DEVICE)
+    optimizer_folded = torch.optim.SGD(folded.parameters(), LR)
+    optimizer_unfolded = torch.optim.SGD(unfolded.parameters(), LR)
+    criterion = nn.MSELoss().to(DEVICE)
+
+    # Test for 10 "epochs" (train + eval)
+    for _ in range(10):
+        folded.train()
+        unfolded.train()
+
+        # inputs and targets:
+        x = torch.rand(x_size, device=DEVICE)
+        y_t = torch.rand_like(param_layer(x))
+
+        # calc loss:
+        optimizer_folded.zero_grad()
+        optimizer_unfolded.zero_grad()
+        loss_folded = criterion(folded(x), y_t)
+        loss_unfolded = criterion(unfolded(x), y_t)
+
+        # calc gradients:
+        loss_folded.backward()
+        loss_unfolded.backward()
+
+        # check the gradients:
+        assert_allclose(unfolded[0].weight.grad, folded.param_module.weight.grad)
+        if has_bias:
+            # The bias of the linear layer doesn't participate in the calculation!
+            # for more details - refer to `FusedLinearBatchNorm.forward`
+            assert folded.param_module.bias.grad is None
+        assert_allclose(unfolded[1].weight.grad, folded.bn.weight.grad)
+        assert_allclose(unfolded[1].bias.grad, folded.bn.bias.grad)
+
+        # make a step:
+        optimizer_unfolded.step()
+        optimizer_folded.step()
+
+        # check updated weights (we skip the linear bias)
+        assert_allclose(unfolded[0].weight, folded.param_module.weight, RTOL, ATOL)
+        assert_allclose(unfolded[1].weight, folded.bn.weight, RTOL, ATOL)
+        assert_allclose(unfolded[1].bias, folded.bn.bias, RTOL, ATOL)
+        assert_allclose(unfolded[1].running_mean, folded.bn.running_mean, RTOL, ATOL)
+        assert_allclose(unfolded[1].running_var, folded.bn.running_var, RTOL, ATOL)
+
+        # testing evaluation:
+        folded.eval()
+        unfolded.eval()
+        x = torch.rand(x_size, device=DEVICE)
+        assert_allclose(unfolded(x), folded(x), RTOL, ATOL)
+
+    # test eval after freezing
+    folded.freeze()
+    x = torch.rand(x_size, device=DEVICE)
+    assert_allclose(unfolded(x), folded(x), RTOL, ATOL)
-- 
GitLab