From 8a2723063273d5ecc2e4d0fe2a1fc84c068ec242 Mon Sep 17 00:00:00 2001
From: Guy Jacob <guy.jacob@intel.com>
Date: Wed, 11 Dec 2019 10:55:32 +0200
Subject: [PATCH] Add support for wide-resnet models that exist in torchvision

---
 distiller/models/__init__.py        |  2 +-
 distiller/models/imagenet/resnet.py | 24 ++++++++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 366edd4..6ac57c8 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -39,7 +39,7 @@ SUPPORTED_DATASETS = ('imagenet', 'cifar10', 'mnist')
 # ResNet special treatment: we have our own version of ResNet, so we need to over-ride
 # TorchVision's version.
 RESNET_SYMS = ('ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
-               'resnext50_32x4d', 'resnext101_32x8d')
+               'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2')
 
 TORCHVISION_MODEL_NAMES = sorted(
                             name for name in torch_models.__dict__
diff --git a/distiller/models/imagenet/resnet.py b/distiller/models/imagenet/resnet.py
index c578434..0d497c2 100755
--- a/distiller/models/imagenet/resnet.py
+++ b/distiller/models/imagenet/resnet.py
@@ -32,6 +32,7 @@ from distiller.modules import EltwiseAdd
 
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+           'wide_resnet50_2', 'wide_resnet101_2',
            'DistillerBottleneck']
 
 
@@ -208,3 +209,26 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
     kwargs['width_per_group'] = 8
     return _resnet('resnext101_32x8d', DistillerBottleneck, [3, 4, 23, 3],
                    pretrained, progress, **kwargs)
+
+def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
+    """Constructs a Wide ResNet-50-2 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet50_2', DistillerBottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
+    """Constructs a Wide ResNet-101-2 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet101_2', DistillerBottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
-- 
GitLab