From 8a2723063273d5ecc2e4d0fe2a1fc84c068ec242 Mon Sep 17 00:00:00 2001 From: Guy Jacob <guy.jacob@intel.com> Date: Wed, 11 Dec 2019 10:55:32 +0200 Subject: [PATCH] Add support for wide-resnet models that exist in torchvision --- distiller/models/__init__.py | 2 +- distiller/models/imagenet/resnet.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py index 366edd4..6ac57c8 100755 --- a/distiller/models/__init__.py +++ b/distiller/models/__init__.py @@ -39,7 +39,7 @@ SUPPORTED_DATASETS = ('imagenet', 'cifar10', 'mnist') # ResNet special treatment: we have our own version of ResNet, so we need to over-ride # TorchVision's version. RESNET_SYMS = ('ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', - 'resnext50_32x4d', 'resnext101_32x8d') + 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2') TORCHVISION_MODEL_NAMES = sorted( name for name in torch_models.__dict__ diff --git a/distiller/models/imagenet/resnet.py b/distiller/models/imagenet/resnet.py index c578434..0d497c2 100755 --- a/distiller/models/imagenet/resnet.py +++ b/distiller/models/imagenet/resnet.py @@ -32,6 +32,7 @@ from distiller.modules import EltwiseAdd __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2', 'DistillerBottleneck'] @@ -208,3 +209,26 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs): kwargs['width_per_group'] = 8 return _resnet('resnext101_32x8d', DistillerBottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + +def wide_resnet50_2(pretrained=False, progress=True, **kwargs): + """Constructs a Wide ResNet-50-2 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', DistillerBottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained=False, progress=True, **kwargs): + """Constructs a Wide ResNet-101-2 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', DistillerBottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) -- GitLab