Commit 9c21c4e3 authored by Neta Zmora

Plain20 - add a version of the Plain20 model w/o BN layers

parent 8ff74211
@@ -35,7 +35,7 @@ import torch.nn as nn
 import math

-__all__ = ['plain20_cifar']
+__all__ = ['plain20_cifar', 'plain20_cifar_nobn']

 NUM_CLASSES = 10
@@ -49,36 +49,36 @@ def conv3x3(in_planes, out_planes, stride=1):
 class BasicBlock(nn.Module):
     expansion = 1

-    def __init__(self, inplanes, planes, stride=1):
+    def __init__(self, inplanes, planes, stride=1, batch_norm=True):
         super().__init__()
         self.conv1 = conv3x3(inplanes, planes, stride)
-        self.bn1 = nn.BatchNorm2d(planes)
+        self.bn1 = nn.BatchNorm2d(planes) if batch_norm else None
         self.relu1 = nn.ReLU(inplace=False)
         self.conv2 = conv3x3(planes, planes)
-        self.bn2 = nn.BatchNorm2d(planes)
+        self.bn2 = nn.BatchNorm2d(planes) if batch_norm else None
         self.relu2 = nn.ReLU(inplace=False)

     def forward(self, x):
         out = self.conv1(x)
-        out = self.bn1(out)
+        out = self.bn1(out) if self.bn1 is not None else out
         out = self.relu1(out)
         out = self.conv2(out)
-        out = self.bn2(out)
+        out = self.bn2(out) if self.bn2 is not None else out
         out = self.relu2(out)
         return out


 class PlainCifar(nn.Module):
-    def __init__(self, block, blks_per_layer, num_classes=NUM_CLASSES):
+    def __init__(self, block, blks_per_layer, num_classes=NUM_CLASSES, batch_norm=True):
         self.inplanes = 16
         super().__init__()
         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(self.inplanes)
+        self.bn1 = nn.BatchNorm2d(self.inplanes) if batch_norm else None
         self.relu = nn.ReLU(inplace=True)
-        self.layer1 = self._make_layer(block, 16, blks_per_layer[0], stride=1)
-        self.layer2 = self._make_layer(block, 32, blks_per_layer[1], stride=2)
-        self.layer3 = self._make_layer(block, 64, blks_per_layer[2], stride=2)
+        self.layer1 = self._make_layer(block, 16, blks_per_layer[0], stride=1, batch_norm=batch_norm)
+        self.layer2 = self._make_layer(block, 32, blks_per_layer[1], stride=2, batch_norm=batch_norm)
+        self.layer3 = self._make_layer(block, 64, blks_per_layer[2], stride=2, batch_norm=batch_norm)
         self.avgpool = nn.AvgPool2d(8, stride=1)
         self.fc = nn.Linear(64 * block.expansion, num_classes)
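Note: with batch_norm=False the bn* attributes hold None rather than a module, so the no-BN variant registers no BatchNorm2d submodules at all. A minimal sanity check of that invariant (a sketch only; the plain20 import path is hypothetical and depends on where this file sits in the repo):

import torch.nn as nn
from plain20 import plain20_cifar, plain20_cifar_nobn  # hypothetical module name

def count_bn(model):
    # None attributes are not registered as submodules, so modules() skips them.
    return sum(1 for m in model.modules() if isinstance(m, nn.BatchNorm2d))

assert count_bn(plain20_cifar()) == 19    # stem BN + 3 stages * 3 blocks * 2 BNs
assert count_bn(plain20_cifar_nobn()) == 0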
@@ -91,22 +91,22 @@ class PlainCifar(nn.Module):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()

-    def _make_layer(self, block, planes, num_blocks, stride):
+    def _make_layer(self, block, planes, num_blocks, stride, batch_norm=True):
         # Each layer is composed of num_blocks blocks (two 3x3 convolutions each),
         # and the first block usually performs downsampling of the input and
         # doubling of the number of filters/feature-maps.
         blocks = []
         inplanes = self.inplanes
         # The first block is special (it downsamples and adds filters)
-        blocks.append(block(inplanes, planes, stride))
+        blocks.append(block(inplanes, planes, stride, batch_norm=batch_norm))
         self.inplanes = planes * block.expansion
         for i in range(num_blocks - 1):
-            blocks.append(block(self.inplanes, planes, stride=1))
+            blocks.append(block(self.inplanes, planes, stride=1, batch_norm=batch_norm))
         return nn.Sequential(*blocks)

     def forward(self, x):
         x = self.conv1(x)
-        x = self.bn1(x)
+        x = self.bn1(x) if self.bn1 is not None else x
         x = self.relu(x)
         x = self.layer1(x)
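For intuition, _make_layer returns an nn.Sequential of num_blocks BasicBlocks in which only the first block carries the stage's stride (and channel widening). A quick inspection, under the same hypothetical import as above:

from plain20 import plain20_cifar  # hypothetical module name

model = plain20_cifar()
first, *rest = model.layer2                # stage 2: planes=32, stride=2, 3 blocks
print(first.conv1.stride)                  # (2, 2) -- downsamples 32x32 -> 16x16
print([b.conv1.stride for b in rest])      # [(1, 1), (1, 1)] -- resolution preserved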
@@ -120,5 +120,13 @@ class PlainCifar(nn.Module):
 def plain20_cifar(**kwargs):
     # Plain20 for CIFAR10
     model = PlainCifar(BasicBlock, [3, 3, 3], **kwargs)
     return model
+    #return plain20_cifar_nobn(**kwargs)
+
+
+def plain20_cifar_nobn(**kwargs):
+    # Plain20 for CIFAR10, without batch-normalization layers
+    model = PlainCifar(BasicBlock, [3, 3, 3], batch_norm=False, **kwargs)
+    return model
\ No newline at end of file
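The "20" in Plain20 counts weight layers: 1 stem convolution + 3 stages x 3 blocks x 2 convolutions + 1 fully-connected layer = 20. A short smoke test that both factory functions build a working CIFAR-10 model (a sketch; it assumes forward() finishes with avgpool and fc as in the full file, and reuses the hypothetical plain20 import path):

import torch
from plain20 import plain20_cifar, plain20_cifar_nobn  # hypothetical module name

x = torch.randn(2, 3, 32, 32)              # a batch of two CIFAR-10-sized images
for factory in (plain20_cifar, plain20_cifar_nobn):
    model = factory().eval()               # eval() freezes BN statistics where present
    with torch.no_grad():
        logits = model(x)
    assert logits.shape == (2, 10)         # NUM_CLASSES = 10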