diff --git a/distiller/models/cifar10/resnet_cifar_earlyexit.py b/distiller/models/cifar10/resnet_cifar_earlyexit.py index c4f75d7288d93ff1efcc65cc7bde0332718923e7..26b67c5bbb91002602c8604ec858d6f95124d29d 100644 --- a/distiller/models/cifar10/resnet_cifar_earlyexit.py +++ b/distiller/models/cifar10/resnet_cifar_earlyexit.py @@ -43,7 +43,7 @@ from .resnet_cifar import ResNetCifar __all__ = ['resnet20_cifar_earlyexit', 'resnet32_cifar_earlyexit', 'resnet44_cifar_earlyexit', - 'resnet56_cifar_earlyexit', 'resnet110_cifar_earlyexit', 'resnet1202_cifar_earlyexit'] + 'resnet56_cifar_earlyexit', 'resnet110_cifar_earlyexit', 'resnet1202_cifar_earlyexit'] NUM_CLASSES = 10 @@ -53,38 +53,47 @@ def conv3x3(in_planes, out_planes, stride=1): padding=1, bias=False) -class ResNetCifarEarlyExit(ResNetCifar): +class ExitBranch(nn.Module): + def __init__(self, num_classes): + super().__init__() + self.avg_pool = nn.AvgPool2d(3) + self.linear = nn.Linear(1600, num_classes) - def __init__(self, block, layers, num_classes=NUM_CLASSES): - super(ResNetCifarEarlyExit, self).__init__(block, layers, num_classes) + def forward(self, x): + x = self.avg_pool(x) + x = x.view(x.size(0), -1) + x = self.linear(x) + return x - # Define early exit layers - self.linear_exit0 = nn.Linear(1600, num_classes) +class BranchPoint(nn.Module): + def __init__(self, original_m, exit_m): + super().__init__() + self.original_m = original_m + self.exit_m = exit_m + self.output = None def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) + x1 = self.original_m.forward(x) + x2 = self.exit_m.forward(x1) + self.output = x2 + return x1 - x = self.layer1(x) - # Add early exit layers - exit0 = nn.functional.avg_pool2d(x, 3) - exit0 = exit0.view(exit0.size(0), -1) - exit0 = self.linear_exit0(exit0) +class ResNetCifarEarlyExit(ResNetCifar): + def __init__(self, block, layers, num_classes=NUM_CLASSES): + super().__init__(block, layers, num_classes) - x = self.layer2(x) - x = self.layer3(x) + # Define early exit branches and install them + self.exit_branch = ExitBranch(num_classes) + self.layer1 = BranchPoint(self.layer1, self.exit_branch) - x = self.avgpool(x) - x = x.view(x.size(0), -1) - x = self.fc(x) + def forward(self, x): + x = super().forward(x) # return a list of probabilities - output = [] - output.append(exit0) - output.append(x) + exit0 = self.layer1.output + output = (exit0, x) return output