Commit 84240654 authored by Guy Jacob

Simulated BN fold module changes

* Support case where BN module has no learnable parameters
  (affine == False)
* Support conv1d and conv3d
parent b60a33ef
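
To illustrate what this change enables, here is a minimal usage sketch (not part of the commit): the import path, shapes and parameter values below are assumptions, but the constructor call mirrors the one used in the tests. A Conv3d followed by a BatchNorm3d with affine=False can now be wrapped:

import torch
import torch.nn as nn
# Import path assumed from the module layout; adjust to your checkout of distiller.
from distiller.quantization.sim_bn_fold import SimulatedFoldedBatchNorm

conv = nn.Conv3d(2, 3, kernel_size=3, bias=True)
bn = nn.BatchNorm3d(3, affine=False)      # no learnable gamma/beta -- now accepted
folded = SimulatedFoldedBatchNorm(conv, bn, param_quantization_fn=None)

x = torch.randn(2, 2, 16, 16, 8)          # (N, C, D, H, W), illustrative shape
y = folded(x)                             # forward with simulated BN folding during training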
@@ -8,6 +8,11 @@ __all__ = ['SimulatedFoldedBatchNorm']
 
 FREEZE_BN_DELAY_DEFAULT = 200000
 
+_conv_meta = {'conv1d': (1, F.conv1d),
+              'conv2d': (2, F.conv2d),
+              'conv3d': (3, F.conv3d)}
+
+
 def _broadcast_correction_factor(c, broadcast_to_shape):
     """
     Returns a view of `c` which is broadcastable with shape `broadcast_to_shape`.
@@ -22,17 +27,16 @@ class SimulatedFoldedBatchNorm(nn.Module):
         """
        Wrapper for simulated folding of BatchNorm into convolution / linear layers during training
        Args:
-            param_module (nn.Linear or nn.Conv2d): the wrapped parameter layer
-            bn (nn.BatchNorm1d or nn.BatchNorm2d): batch normalization
-            freeze_bn_delay (int): number of steps before freezing the batchnorm running stats
+            param_module (nn.Linear or nn.Conv1d or nn.Conv2d or nn.Conv3d): the wrapped parameter module
+            bn (nn.BatchNorm1d or nn.BatchNorm2d or nn.BatchNorm3d): batch normalization module
+            freeze_bn_delay (int): number of steps before freezing the batch-norm running stats
            param_quantization_fn (function): function to be used for weight/bias quantization
        Note:
            The quantized version was implemented according to https://arxiv.org/pdf/1806.08342.pdf Section 3.2.2.
        """
        SimulatedFoldedBatchNorm.verify_module_types(param_module, bn)
-        if not bn.track_running_stats or not bn.affine:
-            raise ValueError("Simulated BN folding is only supported for BatchNorm which tracks running stats with"
-                             "affine weights.")
+        if not bn.track_running_stats:
+            raise ValueError("Simulated BN folding is only supported for BatchNorm which tracks running stats")
         super(SimulatedFoldedBatchNorm, self).__init__()
         self.param_module = param_module
         self.bn = bn
@@ -43,19 +47,30 @@ class SimulatedFoldedBatchNorm(nn.Module):
         if isinstance(param_module, nn.Linear):
             self.param_forward_fn = self._linear_layer_forward
             self.param_module_type = "fc"
-        else:
-            self.param_forward_fn = self._conv2d_layer_forward
+        elif isinstance(param_module, nn.Conv1d):
+            self.param_forward_fn = self._conv_layer_forward
+            self.param_module_type = "conv1d"
+        elif isinstance(param_module, nn.Conv2d):
+            self.param_forward_fn = self._conv_layer_forward
             self.param_module_type = "conv2d"
+        else:
+            self.param_forward_fn = self._conv_layer_forward
+            self.param_module_type = "conv3d"
 
     @staticmethod
     def verify_module_types(param_module, bn):
-        if not isinstance(param_module, (nn.Linear, nn.Conv2d)) \
-                and not isinstance(bn, (nn.BatchNorm1d, nn.BatchNorm2d)):
-            raise TypeError("Only supporting fusing nn.BatchNorm1d/nn.BatchNorm2d into nn.Linear/nn.Conv2d.")
-        if isinstance(param_module, nn.Linear) and isinstance(bn, nn.BatchNorm2d):
-            raise TypeError("nn.Linear layer has to be followed by a nn.BatchNorm1d layer.")
-        if isinstance(param_module, nn.Conv2d) and isinstance(bn, nn.BatchNorm1d):
-            raise TypeError("nn.Con2d layer has to be followed by a nn.BatchNorm2d layer.")
+        foldable_seqs = [((nn.Linear, nn.Conv1d), nn.BatchNorm1d),
+                         (nn.Conv2d, nn.BatchNorm2d),
+                         (nn.Conv3d, nn.BatchNorm3d)]
+        error_msg = "Can't fold sequence of {} --> {}. ".format(param_module.__class__.__name__, bn.__class__.__name__)
+        for seq in foldable_seqs:
+            if isinstance(param_module, seq[0]):
+                if not isinstance(bn, seq[1]):
+                    raise TypeError(error_msg + "{} must be followed by {}".
+                                    format(param_module.__class__.__name__, seq[1].__name__))
+                return
+        raise TypeError(error_msg + "Only Conv/Linear modules followed by BatchNorm modules allowed"
+                        .format(param_module.__class__.__name__, bn.__class__.__name__))
 
     def forward(self, x):
         """
@@ -71,7 +86,7 @@ class SimulatedFoldedBatchNorm(nn.Module):
              = (x*W -E(x*W)) * gamma / std(x*W) + beta
        """
        if not self.frozen:
-            w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias
+            w, b, gamma, beta = self._get_all_parameters()
            if self.training:
                batch_mean, batch_var = self.batch_stats(self.param_forward_fn(x, w), b)
                recip_sigma_batch = torch.rsqrt(batch_var + self.bn.eps)
@@ -104,7 +119,7 @@ class SimulatedFoldedBatchNorm(nn.Module):
        """
        Broadcasts a correction factor to the output for elementwise operations.
        """
-        expected_output_dim = 2 if self.param_module_type == "fc" else 4
+        expected_output_dim = 2 if self.param_module_type == "fc" else _conv_meta[self.param_module_type][0] + 2
        view_fillers_dim = expected_output_dim - c.dim() - 1
        view_filler = (1,) * view_fillers_dim
        expected_view_shape = c.shape + view_filler
@@ -116,7 +131,7 @@ class SimulatedFoldedBatchNorm(nn.Module):
        """
        if c.dim() != 1:
            raise ValueError("Correction factor needs to have a single dimension")
-        expected_weight_dim = 2 if self.param_module_type == "fc" else 4
+        expected_weight_dim = 2 if self.param_module_type == "fc" else _conv_meta[self.param_module_type][0] + 2
        view_fillers_dim = expected_weight_dim - c.dim()
        view_filler = (1,) * view_fillers_dim
        expected_view_shape = c.shape + view_filler
@@ -185,20 +200,24 @@ class SimulatedFoldedBatchNorm(nn.Module):
     def _linear_layer_forward(self, input, w, b=None):
         return F.linear(input, w, b)
 
-    def _conv2d_layer_forward(self, input, w, b=None):
-        # We copy the code from the Conv2d forward, but plug in our weights.
-        conv = self.param_module  # type: nn.Conv2d
-        if conv.__dict__.get('padding_mode', None) == 'circular':  # This attribute doesn't exist yet in pytorch 1.0.1
-            expanded_padding = [(conv.padding[1] + 1) // 2, conv.padding[1] // 2,
-                                (conv.padding[0] + 1) // 2, conv.padding[0] // 2]
-            return F.conv2d(F.pad(input, expanded_padding, mode='circular'),
-                            w, b, conv.stride,
-                            (0, 0), conv.dilation, conv.groups)
-        return F.conv2d(input, w, b, conv.stride,
-                        conv.padding, conv.dilation, conv.groups)
+    def _conv_layer_forward(self, input, w, b=None):
+        # We implement according to Conv1/2/3d.forward(), but plug in our weights
+        conv = self.param_module
+        ndims, func = _conv_meta[self.param_module_type]
+
+        # 'circular' padding doesn't exist pre-pytorch 1.1.0
+        if getattr(conv, 'padding_mode', None) == 'circular':
+            expanded_padding = []
+            for pad_idx in reversed(range(ndims)):
+                expanded_padding.extend([(conv.padding[pad_idx] + 1) // 2, conv.padding[pad_idx] // 2])
+            return func(F.pad(input, expanded_padding, mode='circular'),
+                        w, b, conv.stride,
+                        (0,) * ndims, conv.dilation, conv.groups)
+        return func(input, w, b, conv.stride,
+                    conv.padding, conv.dilation, conv.groups)
 
     def freeze(self):
-        w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias
+        w, b, gamma, beta = self._get_all_parameters()
         with torch.no_grad():
             recip_sigma_running = torch.rsqrt(self.bn.running_var + self.bn.eps)
             w.mul_(self.broadcast_correction_weight(gamma * recip_sigma_running))
@@ -209,3 +228,10 @@ class SimulatedFoldedBatchNorm(nn.Module):
         else:
             self.param_module.bias = nn.Parameter(bias_corrected)
         self.frozen = True
+
+    def _get_all_parameters(self):
+        w, b, gamma, beta = self.param_module.weight, self.param_module.bias, self.bn.weight, self.bn.bias
+        if not self.bn.affine:
+            gamma = 1.
+            beta = 0.
+        return w, b, gamma, beta
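
For reference, the folding relation stated in the forward() docstring above, y = (x*W - E(x*W)) * gamma / std(x*W) + beta, written as a single affine transform; the new _get_all_parameters() helper substitutes gamma = 1 and beta = 0 when the BN module has affine == False (an algebra sketch, not code from this commit):

\[
y \;=\; \frac{\gamma}{\sigma}\,(Wx) \;+\; \Bigl(\beta - \frac{\gamma\,\mu}{\sigma}\Bigr),
\qquad \mu = \mathrm{E}[Wx],\quad \sigma = \sqrt{\operatorname{Var}[Wx] + \epsilon}
\;\;\xrightarrow{\;\gamma = 1,\ \beta = 0\;}\;\;
y \;=\; \frac{Wx - \mu}{\sigma}
\]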
@@ -14,6 +14,22 @@ LR = 1e-3
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
+@pytest.mark.parametrize(
+    "m1, m2",
+    [
+        (nn.ReLU(), nn.BatchNorm1d(5)),
+        (nn.Conv1d(1, 2, 3), nn.ReLU()),
+        (nn.Conv1d(1, 2, 3), nn.BatchNorm2d(2)),
+        (nn.Conv2d(1, 2, 3), nn.BatchNorm3d(2)),
+        (nn.Conv3d(1, 2, 3), nn.BatchNorm2d(2)),
+        (nn.Linear(3, 5), nn.BatchNorm2d(5))
+    ]
+)
+def test_simulated_bn_fold_bad_sequences(m1, m2):
+    with pytest.raises(TypeError):
+        SimulatedFoldedBatchNorm(m1, m2)
+
+
 @pytest.fixture(params=[False, True], ids=['bias_off', 'bias_on'])
 def has_bias(request):
     return request.param
@@ -24,6 +40,11 @@ def momentum(request):
     return request.param
 
 
+@pytest.fixture(params=[True, False], ids=['affine_on', 'affine_off'])
+def affine(request):
+    return request.param
+
+
 @pytest.mark.parametrize(
     "batch_size, input_size, output_size",
     [
@@ -32,12 +53,25 @@ def momentum(request):
         (256, 128, 1024)
     ]
 )
-def test_simulated_bn_fold_fc(has_bias, batch_size, input_size, output_size, momentum):
+def test_simulated_bn_fold_fc(has_bias, batch_size, input_size, output_size, momentum, affine):
     distiller.set_deterministic(1234)
     linear = nn.Linear(input_size, output_size, bias=has_bias)
-    bn = nn.BatchNorm1d(output_size, momentum=momentum)
+    bn = nn.BatchNorm1d(output_size, momentum=momentum, affine=affine)
     run_simulated_bn_fold_test(linear, bn, (batch_size, input_size), has_bias)
 
 
+@pytest.mark.parametrize(
+    "batch_size, input_c, output_c, l, kernel_size",
+    [
+        (50, 3, 100, 80, 10),
+    ]
+)
+def test_simulated_bn_fold_conv1d(has_bias, batch_size, input_c, output_c, l, kernel_size, momentum, affine):
+    distiller.set_deterministic(1234)
+    conv1d = nn.Conv1d(input_c, output_c, kernel_size, bias=has_bias)
+    bn = nn.BatchNorm1d(output_c, momentum=momentum, affine=affine)
+    run_simulated_bn_fold_test(conv1d, bn, (batch_size, input_c, l), has_bias)
+
+
 @pytest.mark.parametrize(
     "batch_size, input_c, output_c, h, w, kernel_size",
@@ -47,13 +81,26 @@ def test_simulated_bn_fold_fc(has_bias, batch_size, input_size, output_size, mom
         (256, 3, 64, 28, 28, 7),
     ]
 )
-def test_simulated_bn_fold_conv(has_bias, batch_size, input_c, output_c, h, w, kernel_size, momentum):
+def test_simulated_bn_fold_conv2d(has_bias, batch_size, input_c, output_c, h, w, kernel_size, momentum, affine):
     distiller.set_deterministic(1234)
     conv2d = nn.Conv2d(input_c, output_c, kernel_size, bias=has_bias)
-    bn = nn.BatchNorm2d(output_c, momentum=momentum)
+    bn = nn.BatchNorm2d(output_c, momentum=momentum, affine=affine)
     run_simulated_bn_fold_test(conv2d, bn, (batch_size, input_c, h, w), has_bias)
 
 
+@pytest.mark.parametrize(
+    "batch_size, input_c, output_c, h, w, d, kernel_size",
+    [
+        (2, 2, 3, 64, 64, 9, 3),
+    ]
+)
+def test_simulated_bn_fold_conv3d(has_bias, batch_size, input_c, output_c, h, w, d, kernel_size, momentum, affine):
+    distiller.set_deterministic(1234)
+    conv3d = nn.Conv3d(input_c, output_c, kernel_size, bias=has_bias)
+    bn = nn.BatchNorm3d(output_c, momentum=momentum, affine=affine)
+    run_simulated_bn_fold_test(conv3d, bn, (batch_size, input_c, h, w, d), has_bias)
+
+
 def run_simulated_bn_fold_test(param_layer, bn_layer, x_size, has_bias):
     folded = SimulatedFoldedBatchNorm(deepcopy(param_layer), deepcopy(bn_layer), param_quantization_fn=None)
     unfolded = nn.Sequential(param_layer, bn_layer)
@@ -82,13 +129,14 @@ def run_simulated_bn_fold_test(param_layer, bn_layer, x_size, has_bias):
     loss_unfolded.backward()
 
     # check the gradients:
-    assert_allclose(unfolded[0].weight.grad, folded.param_module.weight.grad)
+    assert_allclose(unfolded[0].weight.grad, folded.param_module.weight.grad, RTOL, ATOL)
     if has_bias:
         # The bias of the linear layer doesn't participate in the calculation!
         # for more details - refer to `FusedLinearBatchNorm.forward`
         assert folded.param_module.bias.grad is None
-    assert_allclose(unfolded[1].weight.grad, folded.bn.weight.grad)
-    assert_allclose(unfolded[1].bias.grad, folded.bn.bias.grad)
+    if bn_layer.affine:
+        assert_allclose(unfolded[1].weight.grad, folded.bn.weight.grad, RTOL, ATOL)
+        assert_allclose(unfolded[1].bias.grad, folded.bn.bias.grad, RTOL, ATOL)
 
     # make a step:
     optimizer_unfolded.step()
@@ -96,8 +144,9 @@ def run_simulated_bn_fold_test(param_layer, bn_layer, x_size, has_bias):
 
     # check updated weights (we skip the linear bias)
     assert_allclose(unfolded[0].weight, folded.param_module.weight, RTOL, ATOL)
-    assert_allclose(unfolded[1].weight, folded.bn.weight, RTOL, ATOL)
-    assert_allclose(unfolded[1].bias, folded.bn.bias, RTOL, ATOL)
+    if bn_layer.affine:
+        assert_allclose(unfolded[1].weight, folded.bn.weight, RTOL, ATOL)
+        assert_allclose(unfolded[1].bias, folded.bn.bias, RTOL, ATOL)
     assert_allclose(unfolded[1].running_mean, folded.bn.running_mean, RTOL, ATOL)
     assert_allclose(unfolded[1].running_var, folded.bn.running_var, RTOL, ATOL)
 
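
The equivalence that run_simulated_bn_fold_test relies on can also be checked by hand for one of the newly covered combinations. A minimal sketch, assuming the import path below; the shapes and tolerances are illustrative (the real test uses the RTOL/ATOL constants defined in the test module):

from copy import deepcopy

import torch
import torch.nn as nn
from distiller.quantization.sim_bn_fold import SimulatedFoldedBatchNorm  # assumed path

conv = nn.Conv1d(3, 8, kernel_size=5, bias=True)
bn = nn.BatchNorm1d(8, affine=False)               # newly supported combination
folded = SimulatedFoldedBatchNorm(deepcopy(conv), deepcopy(bn), param_quantization_fn=None)
unfolded = nn.Sequential(conv, bn)

folded.train()
unfolded.train()
x = torch.randn(16, 3, 32)
# The simulated-fold module and the plain conv -> BN pair should agree closely in training mode.
assert torch.allclose(folded(x), unfolded(x), rtol=1e-3, atol=1e-5)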