From 842406543ffe28c8a7c3608381af762a0c8538a0 Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Mon, 24 Jun 2019 00:01:59 +0300 Subject: [PATCH] Simulated BN fold module changes * Support case where BN module has no learnable parameters (affine == False) * Support conv1d and conv3d --- distiller/quantization/sim_bn_fold.py | 86 +++++++++++++++++---------- tests/test_sim_bn_fold.py | 69 +++++++++++++++++---- 2 files changed, 115 insertions(+), 40 deletions(-) diff --git a/distiller/quantization/sim_bn_fold.py b/distiller/quantization/sim_bn_fold.py index 8891324..8c3a6f3 100644 --- a/distiller/quantization/sim_bn_fold.py +++ b/distiller/quantization/sim_bn_fold.py @@ -8,6 +8,11 @@ __all__ = ['SimulatedFoldedBatchNorm'] FREEZE_BN_DELAY_DEFAULT = 200000 +_conv_meta = {'conv1d': (1, F.conv1d), + 'conv2d': (2, F.conv2d), + 'conv3d': (3, F.conv3d)} + + def _broadcast_correction_factor(c, broadcast_to_shape): """ Returns a view of `c` which is broadcastable with shape `broadcast_to_shape`. @@ -22,17 +27,16 @@ class SimulatedFoldedBatchNorm(nn.Module): """ Wrapper for simulated folding of BatchNorm into convolution / linear layers during training Args: - param_module (nn.Linear or nn.Conv2d): the wrapped parameter layer - bn (nn.BatchNorm1d or nn.BatchNorm2d): batch normalization - freeze_bn_delay (int): number of steps before freezing the batchnorm running stats + param_module (nn.Linear or nn.Conv1d or nn.Conv2d or nn.Conv3d): the wrapped parameter module + bn (nn.BatchNorm1d or nn.BatchNorm2d or nn.BatchNorm3d): batch normalization module + freeze_bn_delay (int): number of steps before freezing the batch-norm running stats param_quantization_fn (function): function to be used for weight/bias quantization Note: The quantized version was implemented according to https://arxiv.org/pdf/1806.08342.pdf Section 3.2.2. 
""" SimulatedFoldedBatchNorm.verify_module_types(param_module, bn) - if not bn.track_running_stats or not bn.affine: - raise ValueError("Simulated BN folding is only supported for BatchNorm which tracks running stats with" - "affine weights.") + if not bn.track_running_stats: + raise ValueError("Simulated BN folding is only supported for BatchNorm which tracks running stats") super(SimulatedFoldedBatchNorm, self).__init__() self.param_module = param_module self.bn = bn @@ -43,19 +47,30 @@ class SimulatedFoldedBatchNorm(nn.Module): if isinstance(param_module, nn.Linear): self.param_forward_fn = self._linear_layer_forward self.param_module_type = "fc" - else: - self.param_forward_fn = self._conv2d_layer_forward + elif isinstance(param_module, nn.Conv1d): + self.param_forward_fn = self._conv_layer_forward + self.param_module_type = "conv1d" + elif isinstance(param_module, nn.Conv2d): + self.param_forward_fn = self._conv_layer_forward self.param_module_type = "conv2d" + else: + self.param_forward_fn = self._conv_layer_forward + self.param_module_type = "conv3d" @staticmethod def verify_module_types(param_module, bn): - if not isinstance(param_module, (nn.Linear, nn.Conv2d)) \ - and not isinstance(bn, (nn.BatchNorm1d, nn.BatchNorm2d)): - raise TypeError("Only supporting fusing nn.BatchNorm1d/nn.BatchNorm2d into nn.Linear/nn.Conv2d.") - if isinstance(param_module, nn.Linear) and isinstance(bn, nn.BatchNorm2d): - raise TypeError("nn.Linear layer has to be followed by a nn.BatchNorm1d layer.") - if isinstance(param_module, nn.Conv2d) and isinstance(bn, nn.BatchNorm1d): - raise TypeError("nn.Con2d layer has to be followed by a nn.BatchNorm2d layer.") + foldable_seqs = [((nn.Linear, nn.Conv1d), nn.BatchNorm1d), + (nn.Conv2d, nn.BatchNorm2d), + (nn.Conv3d, nn.BatchNorm3d)] + error_msg = "Can't fold sequence of {} --> {}. ".format(param_module.__class__.__name__, bn.__class__.__name__) + for seq in foldable_seqs: + if isinstance(param_module, seq[0]): + if not isinstance(bn, seq[1]): + raise TypeError(error_msg + "{} must be followed by {}". + format(param_module.__class__.__name__, seq[1].__name__)) + return + raise TypeError(error_msg + "Only Conv/Linear modules followed by BatchNorm modules allowed" + .format(param_module.__class__.__name__, bn.__class__.__name__)) def forward(self, x): """ @@ -71,7 +86,7 @@ class SimulatedFoldedBatchNorm(nn.Module): = (x*W -E(x*W)) * gamma / std(x*W) + beta """ if not self.frozen: - w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias + w, b, gamma, beta = self._get_all_parameters() if self.training: batch_mean, batch_var = self.batch_stats(self.param_forward_fn(x, w), b) recip_sigma_batch = torch.rsqrt(batch_var + self.bn.eps) @@ -104,7 +119,7 @@ class SimulatedFoldedBatchNorm(nn.Module): """ Broadcasts a correction factor to the output for elementwise operations. 
""" - expected_output_dim = 2 if self.param_module_type == "fc" else 4 + expected_output_dim = 2 if self.param_module_type == "fc" else _conv_meta[self.param_module_type][0] + 2 view_fillers_dim = expected_output_dim - c.dim() - 1 view_filler = (1,) * view_fillers_dim expected_view_shape = c.shape + view_filler @@ -116,7 +131,7 @@ class SimulatedFoldedBatchNorm(nn.Module): """ if c.dim() != 1: raise ValueError("Correction factor needs to have a single dimension") - expected_weight_dim = 2 if self.param_module_type == "fc" else 4 + expected_weight_dim = 2 if self.param_module_type == "fc" else _conv_meta[self.param_module_type][0] + 2 view_fillers_dim = expected_weight_dim - c.dim() view_filler = (1,) * view_fillers_dim expected_view_shape = c.shape + view_filler @@ -185,20 +200,24 @@ class SimulatedFoldedBatchNorm(nn.Module): def _linear_layer_forward(self, input, w, b=None): return F.linear(input, w, b) - def _conv2d_layer_forward(self, input, w, b=None): - # We copy the code from the Conv2d forward, but plug in our weights. - conv = self.param_module # type: nn.Conv2d - if conv.__dict__.get('padding_mode', None) == 'circular': # This attribute doesn't exist yet in pytorch 1.0.1 - expanded_padding = [(conv.padding[1] + 1) // 2, conv.padding[1] // 2, - (conv.padding[0] + 1) // 2, conv.padding[0] // 2] - return F.conv2d(F.pad(input, expanded_padding, mode='circular'), - w, b, conv.stride, - (0, 0), conv.dilation, conv.groups) - return F.conv2d(input, w, b, conv.stride, - conv.padding, conv.dilation, conv.groups) + def _conv_layer_forward(self, input, w, b=None): + # We implement according to Conv1/2/3d.forward(), but plug in our weights + conv = self.param_module + ndims, func = _conv_meta[self.param_module_type] + + # 'circular' padding doesn't exist pre-pytorch 1.1.0 + if getattr(conv, 'padding_mode', None) == 'circular': + expanded_padding = [] + for pad_idx in reversed(range(ndims)): + expanded_padding.extend([(conv.padding[pad_idx] + 1) // 2, conv.padding[pad_idx] // 2]) + return func(F.pad(input, expanded_padding, mode='circular'), + w, b, conv.stride, + (0,) * ndims, conv.dilation, conv.groups) + return func(input, w, b, conv.stride, + conv.padding, conv.dilation, conv.groups) def freeze(self): - w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias + w, b, gamma, beta = self._get_all_parameters() with torch.no_grad(): recip_sigma_running = torch.rsqrt(self.bn.running_var + self.bn.eps) w.mul_(self.broadcast_correction_weight(gamma * recip_sigma_running)) @@ -209,3 +228,10 @@ class SimulatedFoldedBatchNorm(nn.Module): else: self.param_module.bias = nn.Parameter(bias_corrected) self.frozen = True + + def _get_all_parameters(self): + w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias + if not self.bn.affine: + gamma = 1. + beta = 0. 
+ return w, b, gamma, beta diff --git a/tests/test_sim_bn_fold.py b/tests/test_sim_bn_fold.py index a4c9934..99fa241 100644 --- a/tests/test_sim_bn_fold.py +++ b/tests/test_sim_bn_fold.py @@ -14,6 +14,22 @@ LR = 1e-3 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' +@pytest.mark.parametrize( + "m1, m2", + [ + (nn.ReLU(), nn.BatchNorm1d(5)), + (nn.Conv1d(1, 2, 3), nn.ReLU()), + (nn.Conv1d(1, 2, 3), nn.BatchNorm2d(2)), + (nn.Conv2d(1, 2, 3), nn.BatchNorm3d(2)), + (nn.Conv3d(1, 2, 3), nn.BatchNorm2d(2)), + (nn.Linear(3, 5), nn.BatchNorm2d(5)) + ] +) +def test_simulated_bn_fold_bad_sequences(m1, m2): + with pytest.raises(TypeError): + SimulatedFoldedBatchNorm(m1, m2) + + @pytest.fixture(params=[False, True], ids=['bias_off', 'bias_on']) def has_bias(request): return request.param @@ -24,6 +40,11 @@ def momentum(request): return request.param +@pytest.fixture(params=[True, False], ids=['affine_on', 'affine_off']) +def affine(request): + return request.param + + @pytest.mark.parametrize( "batch_size, input_size, output_size", [ @@ -32,12 +53,25 @@ def momentum(request): (256, 128, 1024) ] ) -def test_simulated_bn_fold_fc(has_bias, batch_size, input_size, output_size, momentum): +def test_simulated_bn_fold_fc(has_bias, batch_size, input_size, output_size, momentum, affine): distiller.set_deterministic(1234) linear = nn.Linear(input_size, output_size, bias=has_bias) - bn = nn.BatchNorm1d(output_size, momentum=momentum) + bn = nn.BatchNorm1d(output_size, momentum=momentum, affine=affine) run_simulated_bn_fold_test(linear, bn, (batch_size, input_size), has_bias) - + + +@pytest.mark.parametrize( + "batch_size, input_c, output_c, l, kernel_size", + [ + (50, 3, 100, 80, 10), + ] +) +def test_simulated_bn_fold_conv1d(has_bias, batch_size, input_c, output_c, l, kernel_size, momentum, affine): + distiller.set_deterministic(1234) + conv1d = nn.Conv1d(input_c, output_c, kernel_size, bias=has_bias) + bn = nn.BatchNorm1d(output_c, momentum=momentum, affine=affine) + run_simulated_bn_fold_test(conv1d, bn, (batch_size, input_c, l), has_bias) + @pytest.mark.parametrize( "batch_size, input_c, output_c, h, w, kernel_size", @@ -47,13 +81,26 @@ def test_simulated_bn_fold_fc(has_bias, batch_size, input_size, output_size, mom (256, 3, 64, 28, 28, 7), ] ) -def test_simulated_bn_fold_conv(has_bias, batch_size, input_c, output_c, h, w, kernel_size, momentum): +def test_simulated_bn_fold_conv2d(has_bias, batch_size, input_c, output_c, h, w, kernel_size, momentum, affine): distiller.set_deterministic(1234) conv2d = nn.Conv2d(input_c, output_c, kernel_size, bias=has_bias) - bn = nn.BatchNorm2d(output_c, momentum=momentum) + bn = nn.BatchNorm2d(output_c, momentum=momentum, affine=affine) run_simulated_bn_fold_test(conv2d, bn, (batch_size, input_c, h, w), has_bias) +@pytest.mark.parametrize( + "batch_size, input_c, output_c, h, w, d, kernel_size", + [ + (2, 2, 3, 64, 64, 9, 3), + ] +) +def test_simulated_bn_fold_conv3d(has_bias, batch_size, input_c, output_c, h, w, d, kernel_size, momentum, affine): + distiller.set_deterministic(1234) + conv3d = nn.Conv3d(input_c, output_c, kernel_size, bias=has_bias) + bn = nn.BatchNorm3d(output_c, momentum=momentum, affine=affine) + run_simulated_bn_fold_test(conv3d, bn, (batch_size, input_c, h, w, d), has_bias) + + def run_simulated_bn_fold_test(param_layer, bn_layer, x_size, has_bias): folded = SimulatedFoldedBatchNorm(deepcopy(param_layer), deepcopy(bn_layer), param_quantization_fn=None) unfolded = nn.Sequential(param_layer, bn_layer) @@ -82,13 +129,14 @@ def 
run_simulated_bn_fold_test(param_layer, bn_layer, x_size, has_bias): loss_unfolded.backward() # check the gradients: - assert_allclose(unfolded[0].weight.grad, folded.param_module.weight.grad) + assert_allclose(unfolded[0].weight.grad, folded.param_module.weight.grad, RTOL, ATOL) if has_bias: # The bias of the linear layer doesn't participate in the calculation! # for more details - refer to `FusedLinearBatchNorm.forward` assert folded.param_module.bias.grad is None - assert_allclose(unfolded[1].weight.grad, folded.bn.weight.grad) - assert_allclose(unfolded[1].bias.grad, folded.bn.bias.grad) + if bn_layer.affine: + assert_allclose(unfolded[1].weight.grad, folded.bn.weight.grad, RTOL, ATOL) + assert_allclose(unfolded[1].bias.grad, folded.bn.bias.grad, RTOL, ATOL) # make a step: optimizer_unfolded.step() @@ -96,8 +144,9 @@ def run_simulated_bn_fold_test(param_layer, bn_layer, x_size, has_bias): # check updated weights (we skip the linear bias) assert_allclose(unfolded[0].weight, folded.param_module.weight, RTOL, ATOL) - assert_allclose(unfolded[1].weight, folded.bn.weight, RTOL, ATOL) - assert_allclose(unfolded[1].bias, folded.bn.bias, RTOL, ATOL) + if bn_layer.affine: + assert_allclose(unfolded[1].weight, folded.bn.weight, RTOL, ATOL) + assert_allclose(unfolded[1].bias, folded.bn.bias, RTOL, ATOL) assert_allclose(unfolded[1].running_mean, folded.bn.running_mean, RTOL, ATOL) assert_allclose(unfolded[1].running_var, folded.bn.running_var, RTOL, ATOL) -- GitLab
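A minimal usage sketch of the folding wrapper after this change (not part of the patch itself), assuming the import path used above (distiller.quantization.sim_bn_fold) and arbitrary illustrative tensor sizes:

import torch
import torch.nn as nn

from distiller.quantization.sim_bn_fold import SimulatedFoldedBatchNorm

# Newly supported combinations: Conv1d followed by a BatchNorm1d with no learnable
# parameters (affine=False), and Conv3d followed by BatchNorm3d.
conv1d = nn.Conv1d(3, 16, kernel_size=5)
bn1d = nn.BatchNorm1d(16, affine=False)   # no gamma/beta; folding treats them as 1 and 0
folded = SimulatedFoldedBatchNorm(conv1d, bn1d, param_quantization_fn=None)

conv3d = nn.Conv3d(2, 4, kernel_size=3)
bn3d = nn.BatchNorm3d(4)
folded_3d = SimulatedFoldedBatchNorm(conv3d, bn3d, param_quantization_fn=None)

# In training mode the wrapper simulates conv -> BN and updates the BN running stats.
folded.train()
x = torch.randn(8, 3, 40)                 # (batch, channels, length)
print(folded(x).shape)                    # torch.Size([8, 16, 36])

# freeze() folds the running statistics (and gamma/beta, when affine=True) into the
# convolution weights and bias, so the frozen output should match the eval-mode
# output up to floating-point error.
folded.eval()
y_before = folded(x)
folded.freeze()
y_after = folded(x)
print(torch.allclose(y_before, y_after, atol=1e-5))

# Mismatched sequences are rejected, as exercised by the new bad-sequence test:
try:
    SimulatedFoldedBatchNorm(nn.Conv1d(1, 2, 3), nn.BatchNorm2d(2))
except TypeError as err:
    print(err)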