Skip to content
Snippets Groups Projects
Commit 8a272306 authored by Guy Jacob's avatar Guy Jacob
Browse files

Add support for wide-resnet models that exist in torchvision

parent 10cd1a85
No related branches found
No related tags found
No related merge requests found
...@@ -39,7 +39,7 @@ SUPPORTED_DATASETS = ('imagenet', 'cifar10', 'mnist') ...@@ -39,7 +39,7 @@ SUPPORTED_DATASETS = ('imagenet', 'cifar10', 'mnist')
# ResNet special treatment: we have our own version of ResNet, so we need to over-ride # ResNet special treatment: we have our own version of ResNet, so we need to over-ride
# TorchVision's version. # TorchVision's version.
RESNET_SYMS = ('ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', RESNET_SYMS = ('ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
'resnext50_32x4d', 'resnext101_32x8d') 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2')
TORCHVISION_MODEL_NAMES = sorted( TORCHVISION_MODEL_NAMES = sorted(
name for name in torch_models.__dict__ name for name in torch_models.__dict__
......
...@@ -32,6 +32,7 @@ from distiller.modules import EltwiseAdd ...@@ -32,6 +32,7 @@ from distiller.modules import EltwiseAdd
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2',
'DistillerBottleneck'] 'DistillerBottleneck']
...@@ -208,3 +209,26 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs): ...@@ -208,3 +209,26 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
kwargs['width_per_group'] = 8 kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', DistillerBottleneck, [3, 4, 23, 3], return _resnet('resnext101_32x8d', DistillerBottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs) pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
"""Constructs a Wide ResNet-50-2 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet50_2', DistillerBottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
"""Constructs a Wide ResNet-101-2 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet101_2', DistillerBottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment