Skip to content
Snippets Groups Projects
Commit 660a0da5 authored by Neta Zmora's avatar Neta Zmora
Browse files

Refactor ResNetCifarEarlyExit

Step 1 of refactoring EE code in order to make it more generic.
parent bc00ee48
No related branches found
No related tags found
No related merge requests found
...@@ -43,7 +43,7 @@ from .resnet_cifar import ResNetCifar ...@@ -43,7 +43,7 @@ from .resnet_cifar import ResNetCifar
__all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit', __all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit',
'resnet56_cifar_earlyexit', 'resnet110_cifar_earlyexit', 'resnet1202_cifar_earlyexit'] 'resnet56_cifar_earlyexit', 'resnet110_cifar_earlyexit', 'resnet1202_cifar_earlyexit']
NUM_CLASSES = 10 NUM_CLASSES = 10
...@@ -53,38 +53,47 @@ def conv3x3(in_planes, out_planes, stride=1): ...@@ -53,38 +53,47 @@ def conv3x3(in_planes, out_planes, stride=1):
padding=1, bias=False) padding=1, bias=False)
class ResNetCifarEarlyExit(ResNetCifar): class ExitBranch(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.avg_pool = nn.AvgPool2d(3)
self.linear = nn.Linear(1600, num_classes)
def __init__(self, block, layers, num_classes=NUM_CLASSES): def forward(self, x):
super(ResNetCifarEarlyExit, self).__init__(block, layers, num_classes) x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.linear(x)
return x
# Define early exit layers
self.linear_exit0 = nn.Linear(1600, num_classes)
class BranchPoint(nn.Module):
def __init__(self, original_m, exit_m):
super().__init__()
self.original_m = original_m
self.exit_m = exit_m
self.output = None
def forward(self, x): def forward(self, x):
x = self.conv1(x) x1 = self.original_m.forward(x)
x = self.bn1(x) x2 = self.exit_m.forward(x1)
x = self.relu(x) self.output = x2
return x1
x = self.layer1(x)
# Add early exit layers class ResNetCifarEarlyExit(ResNetCifar):
exit0 = nn.functional.avg_pool2d(x, 3) def __init__(self, block, layers, num_classes=NUM_CLASSES):
exit0 = exit0.view(exit0.size(0), -1) super().__init__(block, layers, num_classes)
exit0 = self.linear_exit0(exit0)
x = self.layer2(x) # Define early exit branches and install them
x = self.layer3(x) self.exit_branch = ExitBranch(num_classes)
self.layer1 = BranchPoint(self.layer1, self.exit_branch)
x = self.avgpool(x) def forward(self, x):
x = x.view(x.size(0), -1) x = super().forward(x)
x = self.fc(x)
# return a list of probabilities # return a list of probabilities
output = [] exit0 = self.layer1.output
output.append(exit0) output = (exit0, x)
output.append(x)
return output return output
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment