diff --git a/distiller/models/cifar10/plain_cifar.py b/distiller/models/cifar10/plain_cifar.py
index c6a701980afb3096f071dcecc4e68cd126b68453..7f668a5e346d019beb171ff6b2fdbb3e4158053f 100755
--- a/distiller/models/cifar10/plain_cifar.py
+++ b/distiller/models/cifar10/plain_cifar.py
@@ -35,7 +35,7 @@ import torch.nn as nn
 import math
 
-__all__ = ['plain20_cifar']
+__all__ = ['plain20_cifar', 'plain20_cifar_nobn']
 
 NUM_CLASSES = 10
 
 
@@ -49,36 +49,36 @@ def conv3x3(in_planes, out_planes, stride=1):
 class BasicBlock(nn.Module):
     expansion = 1
 
-    def __init__(self, inplanes, planes, stride=1):
+    def __init__(self, inplanes, planes, stride=1, batch_norm=True):
         super().__init__()
         self.conv1 = conv3x3(inplanes, planes, stride)
-        self.bn1 = nn.BatchNorm2d(planes)
+        self.bn1 = nn.BatchNorm2d(planes) if batch_norm else None
         self.relu1 = nn.ReLU(inplace=False)
         self.conv2 = conv3x3(planes, planes)
-        self.bn2 = nn.BatchNorm2d(planes)
+        self.bn2 = nn.BatchNorm2d(planes) if batch_norm else None
         self.relu2 = nn.ReLU(inplace=False)
 
     def forward(self, x):
         out = self.conv1(x)
-        out = self.bn1(out)
+        out = self.bn1(out) if self.bn1 is not None else out
         out = self.relu1(out)
 
         out = self.conv2(out)
-        out = self.bn2(out)
+        out = self.bn2(out) if self.bn2 is not None else out
         out = self.relu2(out)
         return out
 
 
 class PlainCifar(nn.Module):
-    def __init__(self, block, blks_per_layer, num_classes=NUM_CLASSES):
+    def __init__(self, block, blks_per_layer, num_classes=NUM_CLASSES, batch_norm=True):
         self.inplanes = 16
         super().__init__()
         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(self.inplanes)
+        self.bn1 = nn.BatchNorm2d(self.inplanes) if batch_norm else None
         self.relu = nn.ReLU(inplace=True)
-        self.layer1 = self._make_layer(block, 16, blks_per_layer[0], stride=1)
-        self.layer2 = self._make_layer(block, 32, blks_per_layer[1], stride=2)
-        self.layer3 = self._make_layer(block, 64, blks_per_layer[2], stride=2)
+        self.layer1 = self._make_layer(block, 16, blks_per_layer[0], stride=1, batch_norm=batch_norm)
+        self.layer2 = self._make_layer(block, 32, blks_per_layer[1], stride=2, batch_norm=batch_norm)
+        self.layer3 = self._make_layer(block, 64, blks_per_layer[2], stride=2, batch_norm=batch_norm)
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
 
@@ -91,22 +91,22 @@ class PlainCifar(nn.Module):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
 
-    def _make_layer(self, block, planes, num_blocks, stride):
+    def _make_layer(self, block, planes, num_blocks, stride, batch_norm=True):
         # Each layer is composed of num_blocks blocks (each block stacks two 3x3 convolutions),
         # and the first block usually downsamples the input and doubles the number of filters/feature-maps.
         blocks = []
         inplanes = self.inplanes
         # First block is special (downsamples and adds filters)
-        blocks.append(block(inplanes, planes, stride))
+        blocks.append(block(inplanes, planes, stride, batch_norm=batch_norm))
         self.inplanes = planes * block.expansion
         for i in range(num_blocks - 1):
-            blocks.append(block(self.inplanes, planes, stride=1))
+            blocks.append(block(self.inplanes, planes, stride=1, batch_norm=batch_norm))
         return nn.Sequential(*blocks)
 
     def forward(self, x):
         x = self.conv1(x)
-        x = self.bn1(x)
+        x = self.bn1(x) if self.bn1 is not None else x
         x = self.relu(x)
 
         x = self.layer1(x)
         x = self.layer2(x)
@@ -120,5 +120,12 @@ class PlainCifar(nn.Module):
 
 
 def plain20_cifar(**kwargs):
+    # Plain20 for CIFAR10
     model = PlainCifar(BasicBlock, [3, 3, 3], **kwargs)
     return model
+
+
+def plain20_cifar_nobn(**kwargs):
+    # Plain20 for CIFAR10, without batch-normalization layers
+    model = PlainCifar(BasicBlock, [3, 3, 3], batch_norm=False, **kwargs)
+    return model
\ No newline at end of file
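
As a quick smoke test of the two factory functions this diff exports — a minimal sketch, assuming the module path from the diff header, a standard PyTorch install, and that the unchanged tail of `forward()` flattens the pooled activations into `self.fc` as in the rest of the Distiller file:

```python
import torch
from distiller.models.cifar10.plain_cifar import plain20_cifar, plain20_cifar_nobn

# Build both variants; the no-BN model stores None in place of each BatchNorm2d.
model_bn = plain20_cifar()
model_nobn = plain20_cifar_nobn()

# CIFAR-10 shaped input: a batch of 4 RGB 32x32 images.
x = torch.randn(4, 3, 32, 32)
model_bn.eval()
model_nobn.eval()
assert model_bn(x).shape == (4, 10)
assert model_nobn(x).shape == (4, 10)

# The no-BN variant should contain no BatchNorm2d modules at all.
assert not any(isinstance(m, torch.nn.BatchNorm2d) for m in model_nobn.modules())
```

One design note: storing `None` (rather than, say, `nn.Identity`) keeps the attribute out of the module tree entirely, so `modules()` and `state_dict()` of the no-BN model simply omit the batch-norm entries, at the cost of the `if ... is not None` guards in `forward()`.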