diff --git a/distiller/models/imagenet/resnet.py b/distiller/models/imagenet/resnet.py index 6f0fd7ab11ca35492d2432d309dd9083f0b7c2c2..c5784343241006b5bd45788600cbc927421019c6 100755 --- a/distiller/models/imagenet/resnet.py +++ b/distiller/models/imagenet/resnet.py @@ -31,7 +31,8 @@ from distiller.modules import EltwiseAdd __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', - 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'DistillerBottleneck'] class DistillerBasicBlock(BasicBlock): diff --git a/distiller/models/imagenet/resnet_earlyexit.py b/distiller/models/imagenet/resnet_earlyexit.py index 03fd6c9b165a3f2ed988099374914feae13b7bce..def335c7d3570edab6b3a955361e67213f4e0589 100644 --- a/distiller/models/imagenet/resnet_earlyexit.py +++ b/distiller/models/imagenet/resnet_earlyexit.py @@ -1,9 +1,9 @@ import torch.nn as nn import torchvision.models as models -from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck from .resnet import DistillerBottleneck import distiller + __all__ = ['resnet50_earlyexit'] @@ -15,18 +15,15 @@ def conv3x3(in_planes, out_planes, stride=1): def get_exits_def(num_classes): expansion = 1 # models.ResNet.BasicBlock.expansion - exits_def = [('layer1.2.relu3', nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), - nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True), - nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True), - #nn.AdaptiveAvgPool2d((1, 1)), + exits_def = [('layer1.2.relu3', nn.Sequential(nn.Conv2d(256, 10, kernel_size=7, stride=2, padding=3, bias=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), nn.Flatten(), - nn.Linear(12 * expansion, num_classes))), + nn.Linear(1960, num_classes))), #distiller.modules.Print())), ('layer2.3.relu3', nn.Sequential(nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True), - nn.AdaptiveAvgPool2d((1, 1)), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), nn.Flatten(), - #distiller.modules.Print()))] - nn.Linear(12 * expansion, num_classes)))] + nn.Linear(588, num_classes)))] return exits_def