diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py index 366edd4705aa0209a4e885e996108726bceb203d..6ac57c89b7659a5c83a9a29ed6f0f39c8cae7bff 100755 --- a/distiller/models/__init__.py +++ b/distiller/models/__init__.py @@ -39,7 +39,7 @@ SUPPORTED_DATASETS = ('imagenet', 'cifar10', 'mnist') # ResNet special treatment: we have our own version of ResNet, so we need to over-ride # TorchVision's version. RESNET_SYMS = ('ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', - 'resnext50_32x4d', 'resnext101_32x8d') + 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2') TORCHVISION_MODEL_NAMES = sorted( name for name in torch_models.__dict__ diff --git a/distiller/models/imagenet/resnet.py b/distiller/models/imagenet/resnet.py index c5784343241006b5bd45788600cbc927421019c6..0d497c2f240654980da4163a0abed285a90451af 100755 --- a/distiller/models/imagenet/resnet.py +++ b/distiller/models/imagenet/resnet.py @@ -32,6 +32,7 @@ from distiller.modules import EltwiseAdd __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2', 'DistillerBottleneck'] @@ -208,3 +209,26 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs): kwargs['width_per_group'] = 8 return _resnet('resnext101_32x8d', DistillerBottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) + +def wide_resnet50_2(pretrained=False, progress=True, **kwargs): + """Constructs a Wide ResNet-50-2 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', DistillerBottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained=False, progress=True, **kwargs): + """Constructs a Wide ResNet-101-2 model. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', DistillerBottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs)