From 8b341593e8fe71919de149cdd00e269061cecaed Mon Sep 17 00:00:00 2001 From: Neta Zmora <neta.zmora@intel.com> Date: Mon, 25 Nov 2019 16:58:15 +0200 Subject: [PATCH] Resnet50 early-exit update Update the definition of the exits using info from Haim. This is still very unsatsifactory because we don't have working examples to show users :-( --- distiller/models/imagenet/resnet.py | 3 ++- distiller/models/imagenet/resnet_earlyexit.py | 15 ++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/distiller/models/imagenet/resnet.py b/distiller/models/imagenet/resnet.py index 6f0fd7a..c578434 100755 --- a/distiller/models/imagenet/resnet.py +++ b/distiller/models/imagenet/resnet.py @@ -31,7 +31,8 @@ from distiller.modules import EltwiseAdd __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', - 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'DistillerBottleneck'] class DistillerBasicBlock(BasicBlock): diff --git a/distiller/models/imagenet/resnet_earlyexit.py b/distiller/models/imagenet/resnet_earlyexit.py index 03fd6c9..def335c 100644 --- a/distiller/models/imagenet/resnet_earlyexit.py +++ b/distiller/models/imagenet/resnet_earlyexit.py @@ -1,9 +1,9 @@ import torch.nn as nn import torchvision.models as models -from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck from .resnet import DistillerBottleneck import distiller + __all__ = ['resnet50_earlyexit'] @@ -15,18 +15,15 @@ def conv3x3(in_planes, out_planes, stride=1): def get_exits_def(num_classes): expansion = 1 # models.ResNet.BasicBlock.expansion - exits_def = [('layer1.2.relu3', nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), - nn.Conv2d(256, 50, kernel_size=7, stride=2, padding=3, bias=True), - nn.Conv2d(50, 12, kernel_size=7, stride=2, padding=3, bias=True), - #nn.AdaptiveAvgPool2d((1, 1)), + exits_def = [('layer1.2.relu3', nn.Sequential(nn.Conv2d(256, 10, kernel_size=7, stride=2, padding=3, bias=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), nn.Flatten(), - nn.Linear(12 * expansion, num_classes))), + nn.Linear(1960, num_classes))), #distiller.modules.Print())), ('layer2.3.relu3', nn.Sequential(nn.Conv2d(512, 12, kernel_size=7, stride=2, padding=3, bias=True), - nn.AdaptiveAvgPool2d((1, 1)), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), nn.Flatten(), - #distiller.modules.Print()))] - nn.Linear(12 * expansion, num_classes)))] + nn.Linear(588, num_classes)))] return exits_def -- GitLab