From 842406543ffe28c8a7c3608381af762a0c8538a0 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Mon, 24 Jun 2019 00:01:59 +0300
Subject: [PATCH] Simulated BN fold module changes

* Support case where BN module has no learnable parameters
  (affine == False)
* Support conv1d and conv3d
---
 distiller/quantization/sim_bn_fold.py | 86 +++++++++++++++++----------
 tests/test_sim_bn_fold.py             | 69 +++++++++++++++++----
 2 files changed, 115 insertions(+), 40 deletions(-)

diff --git a/distiller/quantization/sim_bn_fold.py b/distiller/quantization/sim_bn_fold.py
index 8891324..8c3a6f3 100644
--- a/distiller/quantization/sim_bn_fold.py
+++ b/distiller/quantization/sim_bn_fold.py
@@ -8,6 +8,11 @@ __all__ = ['SimulatedFoldedBatchNorm']
 FREEZE_BN_DELAY_DEFAULT = 200000
 
 
+_conv_meta = {'conv1d': (1, F.conv1d),
+              'conv2d': (2, F.conv2d),
+              'conv3d': (3, F.conv3d)}
+
+
 def _broadcast_correction_factor(c, broadcast_to_shape):
     """
     Returns a view of `c` which is broadcastable with shape `broadcast_to_shape`.
@@ -22,17 +27,16 @@ class SimulatedFoldedBatchNorm(nn.Module):
         """
         Wrapper for simulated folding of BatchNorm into convolution / linear layers during training
         Args:
-            param_module (nn.Linear or nn.Conv2d): the wrapped parameter layer
-            bn (nn.BatchNorm1d or nn.BatchNorm2d): batch normalization
-            freeze_bn_delay (int): number of steps before freezing the batchnorm running stats
+            param_module (nn.Linear or nn.Conv1d or nn.Conv2d or nn.Conv3d): the wrapped parameter module
+            bn (nn.BatchNorm1d or nn.BatchNorm2d or nn.BatchNorm3d): batch normalization module
+            freeze_bn_delay (int): number of steps before freezing the batch-norm running stats
             param_quantization_fn (function): function to be used for weight/bias quantization
         Note:
             The quantized version was implemented according to https://arxiv.org/pdf/1806.08342.pdf Section 3.2.2.
         """
         SimulatedFoldedBatchNorm.verify_module_types(param_module, bn)
-        if not bn.track_running_stats or not bn.affine:
-            raise ValueError("Simulated BN folding is only supported for BatchNorm which tracks running stats with"
-                             "affine weights.")
+        if not bn.track_running_stats:
+            raise ValueError("Simulated BN folding is only supported for BatchNorm which tracks running stats")
         super(SimulatedFoldedBatchNorm, self).__init__()
         self.param_module = param_module
         self.bn = bn
@@ -43,19 +47,30 @@ class SimulatedFoldedBatchNorm(nn.Module):
         if isinstance(param_module, nn.Linear):
             self.param_forward_fn = self._linear_layer_forward
             self.param_module_type = "fc"
-        else:
-            self.param_forward_fn = self._conv2d_layer_forward
+        elif isinstance(param_module, nn.Conv1d):
+            self.param_forward_fn = self._conv_layer_forward
+            self.param_module_type = "conv1d"
+        elif isinstance(param_module, nn.Conv2d):
+            self.param_forward_fn = self._conv_layer_forward
             self.param_module_type = "conv2d"
+        else:
+            self.param_forward_fn = self._conv_layer_forward
+            self.param_module_type = "conv3d"
 
     @staticmethod
     def verify_module_types(param_module, bn):
-        if not isinstance(param_module, (nn.Linear, nn.Conv2d)) \
-                and not isinstance(bn, (nn.BatchNorm1d, nn.BatchNorm2d)):
-            raise TypeError("Only supporting fusing nn.BatchNorm1d/nn.BatchNorm2d into nn.Linear/nn.Conv2d.")
-        if isinstance(param_module, nn.Linear) and isinstance(bn, nn.BatchNorm2d):
-            raise TypeError("nn.Linear layer has to be followed by a nn.BatchNorm1d layer.")
-        if isinstance(param_module, nn.Conv2d) and isinstance(bn, nn.BatchNorm1d):
-            raise TypeError("nn.Con2d layer has to be followed by a nn.BatchNorm2d layer.")
+        foldable_seqs = [((nn.Linear, nn.Conv1d), nn.BatchNorm1d),
+                         (nn.Conv2d, nn.BatchNorm2d),
+                         (nn.Conv3d, nn.BatchNorm3d)]
+        error_msg = "Can't fold sequence of {} --> {}. ".format(param_module.__class__.__name__, bn.__class__.__name__)
+        for seq in foldable_seqs:
+            if isinstance(param_module, seq[0]):
+                if not isinstance(bn, seq[1]):
+                    raise TypeError(error_msg + "{} must be followed by {}".
+                                    format(param_module.__class__.__name__, seq[1].__name__))
+                return
+        raise TypeError(error_msg + "Only Conv/Linear modules followed by BatchNorm modules allowed"
+                        .format(param_module.__class__.__name__, bn.__class__.__name__))
 
     def forward(self, x):
         """
@@ -71,7 +86,7 @@ class SimulatedFoldedBatchNorm(nn.Module):
                           = (x*W -E(x*W)) * gamma / std(x*W) + beta
         """
         if not self.frozen:
-            w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias
+            w, b, gamma, beta = self._get_all_parameters()
             if self.training:
                 batch_mean, batch_var = self.batch_stats(self.param_forward_fn(x, w), b)
                 recip_sigma_batch = torch.rsqrt(batch_var + self.bn.eps)
@@ -104,7 +119,7 @@ class SimulatedFoldedBatchNorm(nn.Module):
         """
         Broadcasts a correction factor to the output for elementwise operations.
         """
-        expected_output_dim = 2 if self.param_module_type == "fc" else 4
+        expected_output_dim = 2 if self.param_module_type == "fc" else _conv_meta[self.param_module_type][0] + 2
         view_fillers_dim = expected_output_dim - c.dim() - 1
         view_filler = (1,) * view_fillers_dim
         expected_view_shape = c.shape + view_filler
@@ -116,7 +131,7 @@ class SimulatedFoldedBatchNorm(nn.Module):
         """
         if c.dim() != 1:
             raise ValueError("Correction factor needs to have a single dimension")
-        expected_weight_dim = 2 if self.param_module_type == "fc" else 4
+        expected_weight_dim = 2 if self.param_module_type == "fc" else _conv_meta[self.param_module_type][0] + 2
         view_fillers_dim = expected_weight_dim - c.dim()
         view_filler = (1,) * view_fillers_dim
         expected_view_shape = c.shape + view_filler
@@ -185,20 +200,24 @@ class SimulatedFoldedBatchNorm(nn.Module):
     def _linear_layer_forward(self, input, w, b=None):
         return F.linear(input, w, b)
 
-    def _conv2d_layer_forward(self, input, w, b=None):
-        # We copy the code from the Conv2d forward, but plug in our weights.
-        conv = self.param_module  # type: nn.Conv2d
-        if conv.__dict__.get('padding_mode', None) == 'circular':  # This attribute doesn't exist yet in pytorch 1.0.1
-            expanded_padding = [(conv.padding[1] + 1) // 2, conv.padding[1] // 2,
-                                (conv.padding[0] + 1) // 2, conv.padding[0] // 2]
-            return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
-                            w, b, conv.stride,
-                            (0, 0), conv.dilation, conv.groups)
-        return F.conv2d(input, w, b, conv.stride,
-                        conv.padding, conv.dilation, conv.groups)
+    def _conv_layer_forward(self, input, w, b=None):
+        # We implement according to Conv1/2/3d.forward(), but plug in our weights
+        conv = self.param_module
+        ndims, func = _conv_meta[self.param_module_type]
+
+        # 'circular' padding doesn't exist pre-pytorch 1.1.0
+        if getattr(conv, 'padding_mode', None) == 'circular':
+            expanded_padding = []
+            for pad_idx in reversed(range(ndims)):
+                expanded_padding.extend([(conv.padding[pad_idx] + 1) // 2, conv.padding[pad_idx] // 2])
+            return func(F.pad(input, expanded_padding, mode='circular'),
+                        w, b, conv.stride,
+                        (0,) * ndims, conv.dilation, conv.groups)
+        return func(input, w, b, conv.stride,
+                    conv.padding, conv.dilation, conv.groups)
 
     def freeze(self):
-        w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias
+        w, b, gamma, beta = self._get_all_parameters()
         with torch.no_grad():
             recip_sigma_running = torch.rsqrt(self.bn.running_var + self.bn.eps)
             w.mul_(self.broadcast_correction_weight(gamma * recip_sigma_running))
@@ -209,3 +228,10 @@ class SimulatedFoldedBatchNorm(nn.Module):
             else:
                 self.param_module.bias = nn.Parameter(bias_corrected)
         self.frozen = True
+
+    def _get_all_parameters(self):
+        w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias
+        if not self.bn.affine:
+            gamma = 1.
+            beta = 0.
+        return w, b, gamma, beta
diff --git a/tests/test_sim_bn_fold.py b/tests/test_sim_bn_fold.py
index a4c9934..99fa241 100644
--- a/tests/test_sim_bn_fold.py
+++ b/tests/test_sim_bn_fold.py
@@ -14,6 +14,22 @@ LR = 1e-3
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
+@pytest.mark.parametrize(
+    "m1, m2",
+    [
+        (nn.ReLU(), nn.BatchNorm1d(5)),
+        (nn.Conv1d(1, 2, 3), nn.ReLU()),
+        (nn.Conv1d(1, 2, 3), nn.BatchNorm2d(2)),
+        (nn.Conv2d(1, 2, 3), nn.BatchNorm3d(2)),
+        (nn.Conv3d(1, 2, 3), nn.BatchNorm2d(2)),
+        (nn.Linear(3, 5), nn.BatchNorm2d(5))
+    ]
+)
+def test_simulated_bn_fold_bad_sequences(m1, m2):
+    with pytest.raises(TypeError):
+        SimulatedFoldedBatchNorm(m1, m2)
+
+
 @pytest.fixture(params=[False, True], ids=['bias_off', 'bias_on'])
 def has_bias(request):
     return request.param
@@ -24,6 +40,11 @@ def momentum(request):
     return request.param
 
 
+@pytest.fixture(params=[True, False], ids=['affine_on', 'affine_off'])
+def affine(request):
+    return request.param
+
+
 @pytest.mark.parametrize(
     "batch_size, input_size, output_size",
     [
@@ -32,12 +53,25 @@ def momentum(request):
         (256, 128, 1024)
     ]
 )
-def test_simulated_bn_fold_fc(has_bias, batch_size, input_size, output_size, momentum):
+def test_simulated_bn_fold_fc(has_bias, batch_size, input_size, output_size, momentum, affine):
     distiller.set_deterministic(1234)
     linear = nn.Linear(input_size, output_size, bias=has_bias)
-    bn = nn.BatchNorm1d(output_size, momentum=momentum)
+    bn = nn.BatchNorm1d(output_size, momentum=momentum, affine=affine)
     run_simulated_bn_fold_test(linear, bn, (batch_size, input_size), has_bias)
-        
+
+
+@pytest.mark.parametrize(
+    "batch_size, input_c, output_c, l, kernel_size",
+    [
+        (50, 3, 100, 80, 10),
+    ]
+)
+def test_simulated_bn_fold_conv1d(has_bias, batch_size, input_c, output_c, l, kernel_size, momentum, affine):
+    distiller.set_deterministic(1234)
+    conv1d = nn.Conv1d(input_c, output_c, kernel_size, bias=has_bias)
+    bn = nn.BatchNorm1d(output_c, momentum=momentum, affine=affine)
+    run_simulated_bn_fold_test(conv1d, bn, (batch_size, input_c, l), has_bias)
+
 
 @pytest.mark.parametrize(
     "batch_size, input_c, output_c, h, w, kernel_size",
@@ -47,13 +81,26 @@ def test_simulated_bn_fold_fc(has_bias, batch_size, input_size, output_size, mom
         (256, 3, 64, 28, 28, 7),
     ]
 )
-def test_simulated_bn_fold_conv(has_bias, batch_size, input_c, output_c, h, w, kernel_size, momentum):
+def test_simulated_bn_fold_conv2d(has_bias, batch_size, input_c, output_c, h, w, kernel_size, momentum, affine):
     distiller.set_deterministic(1234)
     conv2d = nn.Conv2d(input_c, output_c, kernel_size, bias=has_bias)
-    bn = nn.BatchNorm2d(output_c, momentum=momentum)
+    bn = nn.BatchNorm2d(output_c, momentum=momentum, affine=affine)
     run_simulated_bn_fold_test(conv2d, bn, (batch_size, input_c, h, w), has_bias)
 
 
+@pytest.mark.parametrize(
+    "batch_size, input_c, output_c, h, w, d, kernel_size",
+    [
+        (2, 2, 3, 64, 64, 9, 3),
+    ]
+)
+def test_simulated_bn_fold_conv3d(has_bias, batch_size, input_c, output_c, h, w, d, kernel_size, momentum, affine):
+    distiller.set_deterministic(1234)
+    conv3d = nn.Conv3d(input_c, output_c, kernel_size, bias=has_bias)
+    bn = nn.BatchNorm3d(output_c, momentum=momentum, affine=affine)
+    run_simulated_bn_fold_test(conv3d, bn, (batch_size, input_c, h, w, d), has_bias)
+
+
 def run_simulated_bn_fold_test(param_layer, bn_layer, x_size, has_bias):
     folded = SimulatedFoldedBatchNorm(deepcopy(param_layer), deepcopy(bn_layer), param_quantization_fn=None)
     unfolded = nn.Sequential(param_layer, bn_layer)
@@ -82,13 +129,14 @@ def run_simulated_bn_fold_test(param_layer, bn_layer, x_size, has_bias):
         loss_unfolded.backward()
 
         # check the gradients:
-        assert_allclose(unfolded[0].weight.grad, folded.param_module.weight.grad)
+        assert_allclose(unfolded[0].weight.grad, folded.param_module.weight.grad, RTOL, ATOL)
         if has_bias:
             # The bias of the linear layer doesn't participate in the calculation!
             # for more details - refer to `FusedLinearBatchNorm.forward`
             assert folded.param_module.bias.grad is None
-        assert_allclose(unfolded[1].weight.grad, folded.bn.weight.grad)
-        assert_allclose(unfolded[1].bias.grad, folded.bn.bias.grad)
+        if bn_layer.affine:
+            assert_allclose(unfolded[1].weight.grad, folded.bn.weight.grad, RTOL, ATOL)
+            assert_allclose(unfolded[1].bias.grad, folded.bn.bias.grad, RTOL, ATOL)
 
         # make a step:
         optimizer_unfolded.step()
@@ -96,8 +144,9 @@ def run_simulated_bn_fold_test(param_layer, bn_layer, x_size, has_bias):
 
         # check updated weights (we skip the linear bias)
         assert_allclose(unfolded[0].weight, folded.param_module.weight, RTOL, ATOL)
-        assert_allclose(unfolded[1].weight, folded.bn.weight, RTOL, ATOL)
-        assert_allclose(unfolded[1].bias, folded.bn.bias, RTOL, ATOL)
+        if bn_layer.affine:
+            assert_allclose(unfolded[1].weight, folded.bn.weight, RTOL, ATOL)
+            assert_allclose(unfolded[1].bias, folded.bn.bias, RTOL, ATOL)
         assert_allclose(unfolded[1].running_mean, folded.bn.running_mean, RTOL, ATOL)
         assert_allclose(unfolded[1].running_var, folded.bn.running_var, RTOL, ATOL)
 
-- 
GitLab